Merge pull request #2856 from reuben/training-install

Package training code to avoid sys.path hacks
This commit is contained in:
Reuben Morais 2020-03-31 15:42:42 +02:00 committed by GitHub
commit 83d22e591b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
78 changed files with 3276 additions and 2699 deletions

4
.isort.cfg Normal file
View File

@ -0,0 +1,4 @@
[settings]
line_length=80
multi_line_output=3
default_section=FIRSTPARTY

View File

@ -9,7 +9,7 @@ python:
jobs: jobs:
include: include:
- stage: cardboard linter - name: cardboard linter
install: install:
- pip install --upgrade cardboardlint pylint - pip install --upgrade cardboardlint pylint
script: script:
@ -17,9 +17,10 @@ jobs:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
cardboardlinter --refspec $TRAVIS_BRANCH -n auto; cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
fi fi
- stage: python unit tests - name: python unit tests
install: install:
- pip install --upgrade -r requirements_tests.txt - pip install --upgrade -r requirements_tests.txt;
pip install --upgrade .
script: script:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
python -m unittest; python -m unittest;

View File

@ -2,934 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version, ExceptionBox
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get apropos variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structure in an old
# school TensorFlow structure where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note, that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
Next to total and average loss it returns the mean edit distance,
the decoded result and the batch's original Y.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to non finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPU's.
# The abstraction we intoduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can for each GPU introduce a
tower for which's batch we calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing each variable's average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
else:
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
from tensorflow.python.framework.ops import Tensor, Operation
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__': if __name__ == '__main__':
create_flags() try:
absl.app.run(main) from deepspeech_training import train as ds_train
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_train.run_script()

View File

@ -150,7 +150,7 @@ COPY . /DeepSpeech/
WORKDIR /DeepSpeech WORKDIR /DeepSpeech
RUN pip3 --no-cache-dir install -r requirements.txt RUN pip3 --no-cache-dir install .
# Link DeepSpeech native_client libs to tf folder # Link DeepSpeech native_client libs to tf folder
RUN ln -s /DeepSpeech/native_client /tensorflow RUN ln -s /DeepSpeech/native_client /tensorflow

View File

@ -1,53 +1,69 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help Use "python3 build_sdb.py -h" for help
''' """
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import argparse import argparse
import progressbar import progressbar
from util.downloader import SIMPLE_BAR from deepspeech_training.util.audio import (
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS AUDIO_TYPE_OPUS,
from util.sample_collections import samples_from_files, DirectSDBWriter AUDIO_TYPE_WAV,
change_audio_types,
)
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
DirectSDBWriter,
samples_from_files,
)
AUDIO_TYPE_LOOKUP = { AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
'wav': AUDIO_TYPE_WAV,
'opus': AUDIO_TYPE_OPUS
}
def build_sdb(): def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer: with DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)): for sample in bar(
change_audio_types(
samples, audio_type=audio_type, processes=CLI_ARGS.workers
)
):
sdb_writer.add(sample) sdb_writer.add(sample)
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) ' parser = argparse.ArgumentParser(
'from DeepSpeech CSV files and other SDB files') description="Tool for building Sample Databases (SDB files) "
parser.add_argument('sources', nargs='+', "from DeepSpeech CSV files and other SDB files"
help='Source CSV and/or SDB files - ' )
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples ' parser.add_argument(
'already ordered from shortest to longest.') "sources",
parser.add_argument('target', help='SDB file to create') nargs="+",
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(), help="Source CSV and/or SDB files - "
help='Audio representation inside target SDB') "Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
parser.add_argument('--workers', type=int, default=None, "already ordered from shortest to longest.",
help='Number of encoding SDB workers') )
parser.add_argument('--unlabeled', action='store_true', parser.add_argument("target", help="SDB file to create")
help='If to build an SDB with unlabeled (audio only) samples - ' parser.add_argument(
'typically used for building noise augmentation corpora') "--audio-type",
default="opus",
choices=AUDIO_TYPE_LOOKUP.keys(),
help="Audio representation inside target SDB",
)
parser.add_argument(
"--workers", type=int, default=None, help="Number of encoding SDB workers"
)
parser.add_argument(
"--unlabeled",
action="store_true",
help="If to build an SDB with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora",
)
return parser.parse_args() return parser.parse_args()

View File

@ -1,11 +0,0 @@
#!/usr/bin/env python
import sys
import os
sys.path.append(os.path.abspath('.'))
from util.gpu_usage import GPUUsage
gu = GPUUsage()
gu.start()

View File

@ -1,10 +0,0 @@
#!/usr/bin/env python
import sys
import os
sys.path.append(os.path.abspath('.'))
from util.gpu_usage import GPUUsageChart
GPUUsageChart(sys.argv[1], sys.argv[2])

View File

@ -1,20 +1,21 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tfv1
import sys import sys
import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format from google.protobuf import text_format
def main(): def main():
# Load and export as string # Load and export as string
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin: with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef() graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read()) graph_def.ParseFromString(fin.read())
with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout: with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
fout.write(text_format.MessageToString(graph_def)) fout.write(text_format.MessageToString(graph_def))
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,23 +1,17 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import pandas import os
import tarfile import tarfile
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
@ -25,9 +19,9 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir): def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives # First extract main archive and sub-archives
extract(tgz_file, target_dir) extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'aidatatang_200zh') main_folder = os.path.join(target_dir, "aidatatang_200zh")
for targz in glob.glob(os.path.join(main_folder, 'corpus', '*', '*.tar.gz')): for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
extract(targz, os.path.dirname(targz)) extract(targz, os.path.dirname(targz))
# Folder structure is now: # Folder structure is now:
@ -46,9 +40,11 @@ def preprocess_data(tgz_file, target_dir):
# Since the transcripts themselves can contain spaces, we split on space but # Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript # only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aidatatang_200_zh_transcript.txt') transcripts_path = os.path.join(
main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
)
with open(transcripts_path) as fin: with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin)) transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
def load_set(glob_path): def load_set(glob_path):
set_files = [] set_files = []
@ -57,33 +53,39 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav wav_filename = wav
wav_filesize = os.path.getsize(wav) wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0] transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n') transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
except KeyError: except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav)) print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files return set_files
for subset in ('train', 'dev', 'test'): for subset in ("train", "dev", "test"):
print('Loading {} set samples...'.format(subset)) print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(main_folder, 'corpus', subset, '*', '*.wav')) subset_files = load_set(
os.path.join(main_folder, "corpus", subset, "*", "*.wav")
)
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES) df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s by removing the last couple hundred samples # Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train': if subset == "train":
durations = (df['wav_filesize'] - 44) / 16000 / 2 durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0] df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum())) print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'aidatatang_{}.csv'.format(subset)) dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv)) print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False) df.to_csv(dest_csv, index=False)
def main(): def main():
# https://www.openslr.org/62/ # https://www.openslr.org/62/
parser = get_importers_parser(description='Import aidatatang_200zh corpus') parser = get_importers_parser(description="Import aidatatang_200zh corpus")
parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz') parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:

View File

@ -1,23 +1,17 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import os
import tarfile import tarfile
import pandas import pandas
from deepspeech_training.util.importers import get_importers_parser
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript'] COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
@ -25,10 +19,10 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir): def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives # First extract main archive and sub-archives
extract(tgz_file, target_dir) extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'data_aishell') main_folder = os.path.join(target_dir, "data_aishell")
wav_archives_folder = os.path.join(main_folder, 'wav') wav_archives_folder = os.path.join(main_folder, "wav")
for targz in glob.glob(os.path.join(wav_archives_folder, '*.tar.gz')): for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
extract(targz, main_folder) extract(targz, main_folder)
# Folder structure is now: # Folder structure is now:
@ -45,9 +39,11 @@ def preprocess_data(tgz_file, target_dir):
# Since the transcripts themselves can contain spaces, we split on space but # Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript # only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aishell_transcript_v0.8.txt') transcripts_path = os.path.join(
main_folder, "transcript", "aishell_transcript_v0.8.txt"
)
with open(transcripts_path) as fin: with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin)) transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
def load_set(glob_path): def load_set(glob_path):
set_files = [] set_files = []
@ -56,33 +52,37 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav wav_filename = wav
wav_filesize = os.path.getsize(wav) wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0] transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n') transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
except KeyError: except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav)) print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files return set_files
for subset in ('train', 'dev', 'test'): for subset in ("train", "dev", "test"):
print('Loading {} set samples...'.format(subset)) print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(main_folder, subset, 'S*', '*.wav')) subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES) df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)
# Trim train set to under 10s by removing the last couple hundred samples # Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train': if subset == "train":
durations = (df['wav_filesize'] - 44) / 16000 / 2 durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0] df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum())) print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'aishell_{}.csv'.format(subset)) dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv)) print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False) df.to_csv(dest_csv, index=False)
def main(): def main():
# http://www.openslr.org/33/ # http://www.openslr.org/33/
parser = get_importers_parser(description='Import AISHELL corpus') parser = get_importers_parser(description="Import AISHELL corpus")
parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz') parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:

View File

@ -1,34 +1,35 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv import csv
import sox import os
import tarfile
import subprocess import subprocess
import progressbar import tarfile
from glob import glob from glob import glob
from os import path
from multiprocessing import Pool from multiprocessing import Pool
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] import progressbar
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
print_import_report,
)
from deepspeech_training.util.importers import validate_label_eng as validate_label
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
ARCHIVE_DIR_NAME = 'cv_corpus_v1' ARCHIVE_DIR_NAME = "cv_corpus_v1"
ARCHIVE_NAME = ARCHIVE_DIR_NAME + '.tar.gz' ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' + ARCHIVE_NAME ARCHIVE_URL = (
"https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
)
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data # Conditionally extract common voice data
@ -36,56 +37,70 @@ def _download_and_preprocess_data(target_dir):
# Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav # Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME) _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
else: else:
print('Found directory "%s" - not extracting it from archive.' % extracted_path) print('Found directory "%s" - not extracting it from archive.' % extracted_path)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
for source_csv in glob(path.join(extracted_dir, '*.csv')): for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1])) _maybe_convert_set(
extracted_dir,
source_csv,
os.path.join(target_dir, os.path.split(source_csv)[-1]),
)
def one_sample(sample): def one_sample(sample):
mp3_filename = sample[0] mp3_filename = sample[0]
# Storing wav files next to the mp3 ones - just with a different suffix # Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav" wav_filename = path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename) _maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
)
file_size = -1 file_size = -1
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = validate_label(sample[1]) label = validate_label(sample[1])
rows = [] rows = []
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_set(extracted_dir, source_csv, target_csv): def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print() print()
if path.exists(target_csv): if os.path.exists(target_csv):
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv)) print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
return return
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv)) print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
@ -93,14 +108,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
with open(source_csv) as source_csv_file: with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file) reader = csv.DictReader(source_csv_file)
for row in reader: for row in reader:
samples.append((os.path.join(extracted_dir, row['filename']), row['text'])) samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))
# Mutable counters for the concurrent embedded routine # Mutable counters for the concurrent embedded routine
counter = get_counter() counter = get_counter()
num_samples = len(samples) num_samples = len(samples)
rows = [] rows = []
print('Importing mp3 files...') print("Importing mp3 files...")
pool = Pool() pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@ -112,21 +127,28 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
pool.join() pool.join()
print('Writing "%s"...' % target_csv) print('Writing "%s"...' % target_csv)
with open(target_csv, 'w') as target_csv_file: with open(target_csv, "w") as target_csv_file:
writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES) writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
writer.writeheader() writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows): for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript }) writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE) transformer.convert(samplerate=SAMPLE_RATE)
try: try:
@ -134,5 +156,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
except sox.core.SoxError: except sox.core.SoxError:
pass pass
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])

View File

@ -1,90 +1,96 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Broadly speaking, this script takes the audio downloaded from Common Voice Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator, for a certain language, in addition to the *.tsv files output by CorporaCreator,
and the script formats the data and transcripts to be in a state usable by and the script formats the data and transcripts to be in a state usable by
DeepSpeech.py DeepSpeech.py
Use "python3 import_cv2.py -h" for help Use "python3 import_cv2.py -h" for help
''' """
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv import csv
import sox import os
import subprocess import subprocess
import progressbar
import unicodedata import unicodedata
from os import path
from multiprocessing import Pool from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False): def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']: for dataset in ["train", "test", "dev", "validated", "other"]:
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv") input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
if os.path.isfile(input_tsv): if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv) print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character) _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = sample[0] mp3_filename = sample[0]
if not path.splitext(mp3_filename.lower())[1] == '.mp3': if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
mp3_filename += ".mp3" mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix # Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav" wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename) _maybe_convert_wav(mp3_filename, wav_filename)
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter_fun(sample[1]) label = label_filter_fun(sample[1])
rows = [] rows = []
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label)) rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None): def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv')) output_csv = os.path.join(
audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
)
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv) print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
with open(input_tsv, encoding='utf-8') as input_tsv_file: with open(input_tsv, encoding="utf-8") as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t') reader = csv.DictReader(input_tsv_file, delimiter="\t")
for row in reader: for row in reader:
samples.append((path.join(audio_dir, row['path']), row['sentence'])) samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))
counter = get_counter() counter = get_counter()
num_samples = len(samples) num_samples = len(samples)
@ -101,26 +107,38 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
pool.close() pool.close()
pool.join() pool.join()
with open(output_csv, 'w', encoding='utf-8') as output_csv_file: with open(output_csv, "w", encoding="utf-8") as output_csv_file:
print('Writing CSV file for DeepSpeech.py as: ', output_csv) print("Writing CSV file for DeepSpeech.py as: ", output_csv)
writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES) writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
writer.writeheader() writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows): for filename, file_size, transcript in bar(rows):
if space_after_every_character: if space_after_every_character:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': ' '.join(transcript)}) writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": " ".join(transcript),
}
)
else: else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript}) writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE) transformer.convert(samplerate=SAMPLE_RATE)
try: try:
@ -130,24 +148,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
if __name__ == "__main__": if __name__ == "__main__":
PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora') PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
PARSER.add_argument('tsv_dir', help='Directory containing tsv files') PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"') PARSER.add_argument(
PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') "--audio_dir",
PARSER.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space') )
PARSER.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
PARSER.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
PARSER.add_argument(
"--space_after_every_character",
action="store_true",
help="To help transcript join by white space",
)
PARAMS = PARSER.parse_args() PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS) validate_label = get_validate_label(PARAMS)
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips') AUDIO_DIR = (
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
)
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
def label_filter_fun(label): def label_filter_fun(label):
if PARAMS.normalize: if PARAMS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:

View File

@ -1,25 +1,20 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function import codecs
import fnmatch
import os
import subprocess
import sys
import unicodedata
import librosa
import pandas
import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import validate_label_eng as validate_label
# Prerequisite: Having the sph2pipe tool in your PATH: # Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools # https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import os
import pandas
import subprocess
import unicodedata
import librosa
import soundfile # <= Has an external dependency on libsndfile
from util.importers import validate_label_eng as validate_label
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19 # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
@ -29,33 +24,55 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav") _maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")
# Conditionally split Fisher wav data # Conditionally split Fisher wav data
all_2004 = _split_wav_and_sentences(data_dir, all_2004 = _split_wav_and_sentences(
original_data="fisher-2004-wav", data_dir,
converted_data="fisher-2004-split-wav", original_data="fisher-2004-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans")) converted_data="fisher-2004-split-wav",
all_2005 = _split_wav_and_sentences(data_dir, trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
original_data="fisher-2005-wav", )
converted_data="fisher-2005-split-wav", all_2005 = _split_wav_and_sentences(
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans")) data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
)
# The following files have incorrect transcripts that are much longer than # The following files have incorrect transcripts that are much longer than
# their audio source. The result is that we end up with more labels than time # their audio source. The result is that we end up with more labels than time
# slices, which breaks CTC. # slices, which breaks CTC.
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"), "transcript"] = "correct" all_2004.loc[
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"), "transcript"] = "that's one of those" all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"), "transcript"] = "they don't want" "transcript",
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"), "transcript"] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do" ] = "correct"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
"transcript",
] = "that's one of those"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
"transcript",
] = "they don't want"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
"transcript",
] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
# The following file is just a short sound and not at all transcribed like provided. # The following file is just a short sound and not at all transcribed like provided.
# So we just exclude it. # So we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")] all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
]
# The following file is far too long and would ruin our training batch size. # The following file is far too long and would ruin our training batch size.
# So we just exclude it. # So we just exclude it.
all_2005 = all_2005[~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")] all_2005 = all_2005[
~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
]
# The following file is too large for its transcript, so we just exclude it. # The following file is too large for its transcript, so we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")] all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
]
# Conditionally split Fisher data into train/validation/test sets # Conditionally split Fisher data into train/validation/test sets
train_2004, dev_2004, test_2004 = _split_sets(all_2004) train_2004, dev_2004, test_2004 = _split_sets(all_2004)
@ -71,6 +88,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False) dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False) test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)
def _maybe_convert_wav(data_dir, original_data, converted_data): def _maybe_convert_wav(data_dir, original_data, converted_data):
source_dir = os.path.join(data_dir, original_data) source_dir = os.path.join(data_dir, original_data)
target_dir = os.path.join(data_dir, converted_data) target_dir = os.path.join(data_dir, converted_data)
@ -88,10 +106,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
for filename in fnmatch.filter(filenames, "*.sph"): for filename in fnmatch.filter(filenames, "*.sph"):
sph_file = os.path.join(root, filename) sph_file = os.path.join(root, filename)
for channel in ["1", "2"]: for channel in ["1", "2"]:
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "_c" + channel + ".wav" wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "_c"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename) wav_file = os.path.join(target_dir, wav_filename)
print("converting {} to {}".format(sph_file, wav_file)) print("converting {} to {}".format(sph_file, wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]) subprocess.check_call(
["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
)
def _parse_transcriptions(trans_file): def _parse_transcriptions(trans_file):
segments = [] segments = []
@ -109,18 +135,23 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode # We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array # returns a bytes() object on Python 3, and text_to_char_array
# expects a string. # expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript)
.decode("ascii", "ignore") .encode("ascii", "ignore")
.decode("ascii", "ignore")
)
segments.append({ segments.append(
"start_time": start_time, {
"stop_time": stop_time, "start_time": start_time,
"speaker": speaker, "stop_time": stop_time,
"transcript": transcript, "speaker": speaker,
}) "transcript": transcript,
}
)
return segments return segments
def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data): def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
trans_dir = os.path.join(data_dir, trans_data) trans_dir = os.path.join(data_dir, trans_data)
source_dir = os.path.join(data_dir, original_data) source_dir = os.path.join(data_dir, original_data)
@ -137,43 +168,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
segments = _parse_transcriptions(trans_file) segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file # Open wav corresponding to transcription file
wav_filenames = [os.path.splitext(os.path.basename(trans_file))[0] + "_c" + channel + ".wav" for channel in ["1", "2"]] wav_filenames = [
wav_files = [os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames] os.path.splitext(os.path.basename(trans_file))[0]
+ "_c"
+ channel
+ ".wav"
for channel in ["1", "2"]
]
wav_files = [
os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
]
print("splitting {} according to {}".format(wav_files, trans_file)) print("splitting {} according to {}".format(wav_files, trans_file))
origAudios = [librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files] origAudios = [
librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
]
# Loop over segments and split wav_file for each segment # Loop over segments and split wav_file for each segment
for segment in segments: for segment in segments:
# Create wav segment filename # Create wav segment filename
start_time = segment["start_time"] start_time = segment["start_time"]
stop_time = segment["stop_time"] stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav" new_wav_filename = (
os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = os.path.join(target_dir, new_wav_filename) new_wav_file = os.path.join(target_dir, new_wav_filename)
channel = 0 if segment["speaker"] == "A:" else 1 channel = 0 if segment["speaker"] == "A:" else 1
_split_and_resample_wav(origAudios[channel], start_time, stop_time, new_wav_file) _split_and_resample_wav(
origAudios[channel], start_time, stop_time, new_wav_file
)
new_wav_filesize = os.path.getsize(new_wav_file) new_wav_filesize = os.path.getsize(new_wav_file)
transcript = validate_label(segment["transcript"]) transcript = validate_label(segment["transcript"])
if transcript != None: if transcript != None:
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript)) files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
def _split_audio(origAudio, start_time, stop_time): def _split_audio(origAudio, start_time, stop_time):
audioData, frameRate = origAudio audioData, frameRate = origAudio
nChannels = len(audioData.shape) nChannels = len(audioData.shape)
startIndex = int(start_time * frameRate) startIndex = int(start_time * frameRate)
stopIndex = int(stop_time * frameRate) stopIndex = int(stop_time * frameRate)
return audioData[startIndex: stopIndex] if 1 == nChannels else audioData[:, startIndex: stopIndex] return (
audioData[startIndex:stopIndex]
if 1 == nChannels
else audioData[:, startIndex:stopIndex]
)
def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file): def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio[1] frameRate = origAudio[1]
chunkData = _split_audio(origAudio, start_time, stop_time) chunkData = _split_audio(origAudio, start_time, stop_time)
soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16") soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")
def _split_sets(filelist): def _split_sets(filelist):
# We initially split the entire set into 80% train and 20% test, then # We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation. # split the train set into 80% train and 20% validation.
@ -187,9 +248,12 @@ def _split_sets(filelist):
test_beg = dev_end test_beg = dev_end
test_end = len(filelist) test_end = len(filelist)
return (filelist[train_beg:train_end], return (
filelist[dev_beg:dev_end], filelist[train_beg:train_end],
filelist[test_beg:test_end]) filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])

View File

@ -1,24 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import numpy as np import os
import pandas
import tarfile import tarfile
import numpy as np
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
@ -26,7 +20,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir): def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives # First extract main archive and sub-archives
extract(tgz_file, target_dir) extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS') main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")
# Folder structure is now: # Folder structure is now:
# - ST-CMDS-20170001_1-OS/ # - ST-CMDS-20170001_1-OS/
@ -39,16 +33,16 @@ def preprocess_data(tgz_file, target_dir):
for wav in glob.glob(glob_path): for wav in glob.glob(glob_path):
wav_filename = wav wav_filename = wav
wav_filesize = os.path.getsize(wav) wav_filesize = os.path.getsize(wav)
txt_filename = os.path.splitext(wav_filename)[0] + '.txt' txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
with open(txt_filename, 'r') as fin: with open(txt_filename, "r") as fin:
transcript = fin.read() transcript = fin.read()
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
return set_files return set_files
# Load all files, then deterministically split into train/dev/test sets # Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, '*.wav')) all_files = load_set(os.path.join(main_folder, "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES) df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True) df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df)) indices = np.arange(0, len(df))
np.random.seed(12345) np.random.seed(12345)
@ -61,29 +55,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000] train_indices = indices[:-10000]
train_files = df.iloc[train_indices] train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2 durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 10.0] train_files = train_files[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum())) print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'freestmandarin_train.csv') dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
print('Saving train set into {}...'.format(dest_csv)) print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False) train_files.to_csv(dest_csv, index=False)
dev_files = df.iloc[dev_indices] dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_dev.csv') dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
print('Saving dev set into {}...'.format(dest_csv)) print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False) dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices] test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_test.csv') dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
print('Saving test set into {}...'.format(dest_csv)) print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False) test_files.to_csv(dest_csv, index=False)
def main(): def main():
# https://www.openslr.org/38/ # https://www.openslr.org/38/
parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus') parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz') parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:

View File

@ -1,24 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv import csv
import math
import urllib
import logging import logging
from util.importers import get_importers_parser, get_validate_label import math
import os
import subprocess import subprocess
from os import path import urllib
from pathlib import Path from pathlib import Path
import swifter
import pandas as pd import pandas as pd
from sox import Transformer from sox import Transformer
import swifter
from deepspeech_training.util.importers import get_importers_parser, get_validate_label
__version__ = "0.1.0" __version__ = "0.1.0"
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -40,9 +34,7 @@ def parse_args(args):
Returns: Returns:
:obj:`argparse.Namespace`: command line parameters namespace :obj:`argparse.Namespace`: command line parameters namespace
""" """
parser = get_importers_parser( parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
description="Imports GramVaani data for Deep Speech"
)
parser.add_argument( parser.add_argument(
"--version", "--version",
action="version", action="version",
@ -82,6 +74,7 @@ def parse_args(args):
) )
return parser.parse_args(args) return parser.parse_args(args)
def setup_logging(level): def setup_logging(level):
"""Setup basic logging """Setup basic logging
Args: Args:
@ -92,6 +85,7 @@ def setup_logging(level):
level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S" level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
) )
class GramVaaniCSV: class GramVaaniCSV:
"""GramVaaniCSV representing a GramVaani dataset. """GramVaaniCSV representing a GramVaani dataset.
Args: Args:
@ -107,8 +101,17 @@ class GramVaaniCSV:
_logger.info("Parsing csv file...%s", os.path.abspath(csv_filename)) _logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
data = pd.read_csv( data = pd.read_csv(
os.path.abspath(csv_filename), os.path.abspath(csv_filename),
names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"], names=[
usecols=["audio_url","transcript","audio_length"], "piece_id",
"audio_url",
"transcript_labelled",
"transcript",
"labels",
"content_filename",
"audio_length",
"user_id",
],
usecols=["audio_url", "transcript", "audio_length"],
skiprows=[0], skiprows=[0],
engine="python", engine="python",
encoding="utf-8", encoding="utf-8",
@ -119,6 +122,7 @@ class GramVaaniCSV:
_logger.info("Parsed %d lines csv file." % len(data)) _logger.info("Parsed %d lines csv file." % len(data))
return data return data
class GramVaaniDownloader: class GramVaaniDownloader:
"""GramVaaniDownloader downloads a GramVaani dataset. """GramVaaniDownloader downloads a GramVaani dataset.
Args: Args:
@ -138,15 +142,17 @@ class GramVaaniDownloader:
mp3_directory (os.path): The directory into which the associated mp3's were downloaded mp3_directory (os.path): The directory into which the associated mp3's were downloaded
""" """
mp3_directory = self._pre_download() mp3_directory = self._pre_download()
self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True) self.data.swifter.apply(
func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
)
return mp3_directory return mp3_directory
def _pre_download(self): def _pre_download(self):
mp3_directory = path.join(self.target_dir, "mp3") mp3_directory = os.path.join(self.target_dir, "mp3")
if not path.exists(self.target_dir): if not os.path.exists(self.target_dir):
_logger.info("Creating directory...%s", self.target_dir) _logger.info("Creating directory...%s", self.target_dir)
os.mkdir(self.target_dir) os.mkdir(self.target_dir)
if not path.exists(mp3_directory): if not os.path.exists(mp3_directory):
_logger.info("Creating directory...%s", mp3_directory) _logger.info("Creating directory...%s", mp3_directory)
os.mkdir(mp3_directory) os.mkdir(mp3_directory)
return mp3_directory return mp3_directory
@ -154,13 +160,14 @@ class GramVaaniDownloader:
def _download(self, audio_url, transcript, audio_length, mp3_directory): def _download(self, audio_url, transcript, audio_length, mp3_directory):
if audio_url == "audio_url": if audio_url == "audio_url":
return return
mp3_filename = path.join(mp3_directory, os.path.basename(audio_url)) mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
if not path.exists(mp3_filename): if not os.path.exists(mp3_filename):
_logger.debug("Downloading mp3 file...%s", audio_url) _logger.debug("Downloading mp3 file...%s", audio_url)
urllib.request.urlretrieve(audio_url, mp3_filename) urllib.request.urlretrieve(audio_url, mp3_filename)
else: else:
_logger.debug("Already downloaded mp3 file...%s", audio_url) _logger.debug("Already downloaded mp3 file...%s", audio_url)
class GramVaaniConverter: class GramVaaniConverter:
"""GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset. """GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
Args: Args:
@ -181,37 +188,53 @@ class GramVaaniConverter:
wav_directory (os.path): The directory into which the associated wav's were downloaded wav_directory (os.path): The directory into which the associated wav's were downloaded
""" """
wav_directory = self._pre_convert() wav_directory = self._pre_convert()
for mp3_filename in self.mp3_directory.glob('**/*.mp3'): for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav") wav_filename = os.path.join(
if not path.exists(wav_filename): wav_directory,
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename)) os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
)
if not os.path.exists(wav_filename):
_logger.debug(
"Converting mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
transformer = Transformer() transformer = Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH) transformer.convert(
samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
)
transformer.build(str(mp3_filename), str(wav_filename)) transformer.build(str(mp3_filename), str(wav_filename))
else: else:
_logger.debug("Already converted mp3 file %s to wav file %s" % (mp3_filename, wav_filename)) _logger.debug(
"Already converted mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
return wav_directory return wav_directory
def _pre_convert(self): def _pre_convert(self):
wav_directory = path.join(self.target_dir, "wav") wav_directory = os.path.join(self.target_dir, "wav")
if not path.exists(self.target_dir): if not os.path.exists(self.target_dir):
_logger.info("Creating directory...%s", self.target_dir) _logger.info("Creating directory...%s", self.target_dir)
os.mkdir(self.target_dir) os.mkdir(self.target_dir)
if not path.exists(wav_directory): if not os.path.exists(wav_directory):
_logger.info("Creating directory...%s", wav_directory) _logger.info("Creating directory...%s", wav_directory)
os.mkdir(wav_directory) os.mkdir(wav_directory)
return wav_directory return wav_directory
class GramVaaniDataSets: class GramVaaniDataSets:
def __init__(self, target_dir, wav_directory, gram_vaani_csv): def __init__(self, target_dir, wav_directory, gram_vaani_csv):
self.target_dir = target_dir self.target_dir = target_dir
self.wav_directory = wav_directory self.wav_directory = wav_directory
self.csv_data = gram_vaani_csv.data self.csv_data = gram_vaani_csv.data
self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"]) self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.valid = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"]) self.valid = pd.DataFrame(
self.train = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"]) columns=["wav_filename", "wav_filesize", "transcript"]
self.dev = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"]) )
self.test = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"]) self.train = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
def create(self): def create(self):
self._convert_csv_data_to_raw_data() self._convert_csv_data_to_raw_data()
@ -220,21 +243,31 @@ class GramVaaniDataSets:
self.valid = self.valid.sample(frac=1).reset_index(drop=True) self.valid = self.valid.sample(frac=1).reset_index(drop=True)
train_size, dev_size, test_size = self._calculate_data_set_sizes() train_size, dev_size, test_size = self._calculate_data_set_sizes()
self.train = self.valid.loc[0:train_size] self.train = self.valid.loc[0:train_size]
self.dev = self.valid.loc[train_size:train_size+dev_size] self.dev = self.valid.loc[train_size : train_size + dev_size]
self.test = self.valid.loc[train_size+dev_size:train_size+dev_size+test_size] self.test = self.valid.loc[
train_size + dev_size : train_size + dev_size + test_size
]
def _convert_csv_data_to_raw_data(self): def _convert_csv_data_to_raw_data(self):
self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[ self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
["audio_url","transcript","audio_length"] ["audio_url", "transcript", "audio_length"]
].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True) ].swifter.apply(
func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
axis=1,
raw=True,
)
self.raw.reset_index() self.raw.reset_index()
def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length): def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
if audio_url == "audio_url": if audio_url == "audio_url":
return pd.Series(["wav_filename", "wav_filesize", "transcript"]) return pd.Series(["wav_filename", "wav_filesize", "transcript"])
mp3_filename = os.path.basename(audio_url) mp3_filename = os.path.basename(audio_url)
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav") wav_relative_filename = os.path.join(
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename)) "wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
)
wav_filesize = os.path.getsize(
os.path.join(self.target_dir, wav_relative_filename)
)
transcript = validate_label(transcript) transcript = validate_label(transcript)
if None == transcript: if None == transcript:
transcript = "" transcript = ""
@ -243,7 +276,12 @@ class GramVaaniDataSets:
def _is_valid_raw_rows(self): def _is_valid_raw_rows(self):
is_valid_raw_transcripts = self._is_valid_raw_transcripts() is_valid_raw_transcripts = self._is_valid_raw_transcripts()
is_valid_raw_wav_frames = self._is_valid_raw_wav_frames() is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
is_valid_raw_row = [(is_valid_raw_transcript & is_valid_raw_wav_frame) for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(is_valid_raw_transcripts, is_valid_raw_wav_frames)] is_valid_raw_row = [
(is_valid_raw_transcript & is_valid_raw_wav_frame)
for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
is_valid_raw_transcripts, is_valid_raw_wav_frames
)
]
series = pd.Series(is_valid_raw_row) series = pd.Series(is_valid_raw_row)
return series return series
@ -252,16 +290,29 @@ class GramVaaniDataSets:
def _is_valid_raw_wav_frames(self): def _is_valid_raw_wav_frames(self):
transcripts = [str(transcript) for transcript in self.raw.transcript] transcripts = [str(transcript) for transcript in self.raw.transcript]
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename] wav_filepaths = [
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths] os.path.join(self.target_dir, str(wav_filename))
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)] for wav_filename in self.raw.wav_filename
]
wav_frames = [
int(
subprocess.check_output(
["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
)
)
for wav_filepath in wav_filepaths
]
is_valid_raw_wav_frames = [
self._is_wav_frame_valid(wav_frame, transcript)
for wav_frame, transcript in zip(wav_frames, transcripts)
]
return pd.Series(is_valid_raw_wav_frames) return pd.Series(is_valid_raw_wav_frames)
def _is_wav_frame_valid(self, wav_frame, transcript): def _is_wav_frame_valid(self, wav_frame, transcript):
is_wav_frame_valid = True is_wav_frame_valid = True
if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)): if int(wav_frame / SAMPLE_RATE * 1000 / 10 / 2) < len(str(transcript)):
is_wav_frame_valid = False is_wav_frame_valid = False
elif wav_frame/SAMPLE_RATE > MAX_SECS: elif wav_frame / SAMPLE_RATE > MAX_SECS:
is_wav_frame_valid = False is_wav_frame_valid = False
return is_wav_frame_valid return is_wav_frame_valid
@ -280,7 +331,14 @@ class GramVaaniDataSets:
def _save(self, dataset): def _save(self, dataset):
dataset_path = os.path.join(self.target_dir, dataset + ".csv") dataset_path = os.path.join(self.target_dir, dataset + ".csv")
dataframe = getattr(self, dataset) dataframe = getattr(self, dataset)
dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL) dataframe.to_csv(
dataset_path,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_MINIMAL,
)
def main(args): def main(args):
"""Main entry point allowing external calls """Main entry point allowing external calls
@ -304,4 +362,5 @@ def main(args):
datasets.save() datasets.save()
_logger.info("Finished GramVaani importer...") _logger.info("Finished GramVaani importer...")
main(sys.argv[1:]) main(sys.argv[1:])

View File

@ -1,28 +1,33 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os import os
sys.path.insert(1, os.path.join(sys.path[0], '..')) import sys
import pandas import pandas
from util.downloader import maybe_download from deepspeech_training.util.downloader import maybe_download
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data # Conditionally download data
LDC93S1_BASE = "LDC93S1" LDC93S1_BASE = "LDC93S1"
LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/" LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav") local_file = maybe_download(
trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt") LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
)
trans_file = maybe_download(
LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
)
with open(trans_file, "r") as fin: with open(trans_file, "r") as fin:
transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '') transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
".", ""
)
df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)], df = pandas.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]) data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"],
)
df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False) df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])

View File

@ -1,33 +1,39 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs import codecs
import fnmatch import fnmatch
import pandas import os
import progressbar
import subprocess import subprocess
import sys
import tarfile import tarfile
import unicodedata import unicodedata
import pandas
import progressbar
from sox import Transformer from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from deepspeech_training.util.downloader import maybe_download
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir # Conditionally download data to data_dir
print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir)) print(
"Downloading Librivox data set (55GB) into {} if not already present...".format(
data_dir
)
)
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar: with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz" TRAIN_CLEAN_100_URL = (
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz" "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz" )
TRAIN_CLEAN_360_URL = (
"http://www.openslr.org/resources/12/train-clean-360.tar.gz"
)
TRAIN_OTHER_500_URL = (
"http://www.openslr.org/resources/12/train-other-500.tar.gz"
)
DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz" DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz" DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
@ -35,12 +41,20 @@ def _download_and_preprocess_data(data_dir):
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz" TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz" TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"
def filename_of(x): return os.path.split(x)[1] def filename_of(x):
train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL) return os.path.split(x)[1]
train_clean_100 = maybe_download(
filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
)
bar.update(0) bar.update(0)
train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL) train_clean_360 = maybe_download(
filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
)
bar.update(1) bar.update(1)
train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL) train_other_500 = maybe_download(
filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
)
bar.update(2) bar.update(2)
dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL) dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
@ -48,9 +62,13 @@ def _download_and_preprocess_data(data_dir):
dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL) dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
bar.update(4) bar.update(4)
test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL) test_clean = maybe_download(
filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
)
bar.update(5) bar.update(5)
test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL) test_other = maybe_download(
filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
)
bar.update(6) bar.update(6)
# Conditionally extract LibriSpeech data # Conditionally extract LibriSpeech data
@ -61,11 +79,17 @@ def _download_and_preprocess_data(data_dir):
LIBRIVOX_DIR = "LibriSpeech" LIBRIVOX_DIR = "LibriSpeech"
work_dir = os.path.join(data_dir, LIBRIVOX_DIR) work_dir = os.path.join(data_dir, LIBRIVOX_DIR)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100) _maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
)
bar.update(0) bar.update(0)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360) _maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
)
bar.update(1) bar.update(1)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500) _maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
)
bar.update(2) bar.update(2)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean) _maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
@ -91,28 +115,48 @@ def _download_and_preprocess_data(data_dir):
# data_dir/LibriSpeech/split-wav/1-2-2.txt # data_dir/LibriSpeech/split-wav/1-2-2.txt
# ... # ...
print("Converting FLAC to WAV and splitting transcriptions...") print("Converting FLAC to WAV and splitting transcriptions...")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar: with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav") train_100 = _convert_audio_and_split_sentences(
work_dir, "train-clean-100", "train-clean-100-wav"
)
bar.update(0) bar.update(0)
train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav") train_360 = _convert_audio_and_split_sentences(
work_dir, "train-clean-360", "train-clean-360-wav"
)
bar.update(1) bar.update(1)
train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav") train_500 = _convert_audio_and_split_sentences(
work_dir, "train-other-500", "train-other-500-wav"
)
bar.update(2) bar.update(2)
dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav") dev_clean = _convert_audio_and_split_sentences(
work_dir, "dev-clean", "dev-clean-wav"
)
bar.update(3) bar.update(3)
dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav") dev_other = _convert_audio_and_split_sentences(
work_dir, "dev-other", "dev-other-wav"
)
bar.update(4) bar.update(4)
test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav") test_clean = _convert_audio_and_split_sentences(
work_dir, "test-clean", "test-clean-wav"
)
bar.update(5) bar.update(5)
test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav") test_other = _convert_audio_and_split_sentences(
work_dir, "test-other", "test-other-wav"
)
bar.update(6) bar.update(6)
# Write sets to disk as CSV files # Write sets to disk as CSV files
train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False) train_100.to_csv(
train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False) os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False) )
train_360.to_csv(
os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
)
train_500.to_csv(
os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
)
dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False) dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False) dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
@ -120,6 +164,7 @@ def _download_and_preprocess_data(data_dir):
test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False) test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False) test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive): def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir # If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(os.path.join(data_dir, extracted_data)): if not gfile.Exists(os.path.join(data_dir, extracted_data)):
@ -127,6 +172,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir) tar.extractall(data_dir)
tar.close() tar.close()
def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir): def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
source_dir = os.path.join(extracted_dir, data_set) source_dir = os.path.join(extracted_dir, data_set)
target_dir = os.path.join(extracted_dir, dest_dir) target_dir = os.path.join(extracted_dir, dest_dir)
@ -149,20 +195,22 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
# We also convert the corresponding FLACs to WAV in the same pass # We also convert the corresponding FLACs to WAV in the same pass
files = [] files = []
for root, dirnames, filenames in os.walk(source_dir): for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, '*.trans.txt'): for filename in fnmatch.filter(filenames, "*.trans.txt"):
trans_filename = os.path.join(root, filename) trans_filename = os.path.join(root, filename)
with codecs.open(trans_filename, "r", "utf-8") as fin: with codecs.open(trans_filename, "r", "utf-8") as fin:
for line in fin: for line in fin:
# Parse each segment line # Parse each segment line
first_space = line.find(" ") first_space = line.find(" ")
seqid, transcript = line[:first_space], line[first_space+1:] seqid, transcript = line[:first_space], line[first_space + 1 :]
# We need to do the encode-decode dance here because encode # We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array # returns a bytes() object on Python 3, and text_to_char_array
# expects a string. # expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript)
.decode("ascii", "ignore") .encode("ascii", "ignore")
.decode("ascii", "ignore")
)
transcript = transcript.lower().strip() transcript = transcript.lower().strip()
@ -177,7 +225,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
files.append((os.path.abspath(wav_file), wav_filesize, transcript)) files.append((os.path.abspath(wav_file), wav_filesize, transcript))
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])

View File

@ -1,44 +1,39 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse import argparse
import csv import csv
import os
import re import re
import sox
import zipfile
import subprocess import subprocess
import progressbar
import unicodedata import unicodedata
import zipfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from glob import glob from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download import progressbar
from util.text import Alphabet import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
ARCHIVE_DIR_NAME = 'lingua_libre' ARCHIVE_DIR_NAME = "lingua_libre"
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip' ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data # Conditionally extract data
@ -46,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files and convert ogg data to wav # Produce CSV files and convert ogg data to wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME) _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@ -58,57 +54,70 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = sample[0] ogg_filename = sample[0]
# Storing wav files next to the ogg ones - just with a different suffix # Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav" wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename) _maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1]) label = label_filter(sample[1])
rows = [] rows = []
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv')) target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
)
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', '')) ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg') glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '') record_file = record.replace(ogg_root_dir + os.path.sep, "")
if record_filter(record_file): if record_filter(record_file):
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0])) samples.append(
(
os.path.join(ogg_root_dir, record_file),
os.path.splitext(os.path.basename(record_file))[0],
)
)
counter = get_counter() counter = get_counter()
num_samples = len(samples) num_samples = len(samples)
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -139,7 +148,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = validate_label(item[2]) transcript = validate_label(item[2])
if not transcript: if not transcript:
continue continue
wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav')) wav_filename = os.path.join(
ogg_root_dir, item[0].replace(".ogg", ".wav")
)
i_mod = i % 10 i_mod = i % 10
if i_mod == 0: if i_mod == 0:
writer = test_writer writer = test_writer
@ -147,38 +158,63 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
wav_filename=wav_filename, dict(
wav_filesize=os.path.getsize(wav_filename), wav_filename=wav_filename,
transcript=transcript, wav_filesize=os.path.getsize(wav_filename),
)) transcript=transcript,
)
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(ogg_filename, wav_filename): def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE) transformer.convert(samplerate=SAMPLE_RATE)
try: try:
transformer.build(ogg_filename, wav_filename) transformer.build(ogg_filename, wav_filename)
except sox.core.SoxError as ex: except sox.core.SoxError as ex:
print('SoX processing error', ex, ogg_filename, wav_filename) print("SoX processing error", ex, ogg_filename, wav_filename)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.') parser = get_importers_parser(
parser.add_argument(dest='target_dir') description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId') )
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code') parser.add_argument(dest="target_dir")
parser.add_argument('--english-name', type=str, required=True, help='Enligh name of the language') parser.add_argument(
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') "--qId", type=int, required=True, help="LinguaLibre language qId"
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') )
parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items') parser.add_argument(
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
)
parser.add_argument(
"--english-name", type=str, required=True, help="Enligh name of the language"
)
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--bogus-records",
type=argparse.FileType("r"),
required=False,
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -191,15 +227,17 @@ if __name__ == "__main__":
def record_filter(path): def record_filter(path):
if any(regex.match(path) for regex in bogus_regexes): if any(regex.match(path) for regex in bogus_regexes):
print('Reject', path) print("Reject", path)
return False return False
return True return True
def label_filter(label): def label_filter(label):
if CLI_ARGS.normalize: if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:
@ -208,6 +246,14 @@ if __name__ == "__main__":
label = None label = None
return label return label
ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name) ARCHIVE_NAME = ARCHIVE_NAME.format(
ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name) qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
ARCHIVE_URL = ARCHIVE_URL.format(
qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir) _download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)

View File

@ -1,43 +1,37 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# pylint: disable=invalid-name # pylint: disable=invalid-name
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import os
import subprocess import subprocess
import progressbar
import unicodedata
import tarfile import tarfile
import unicodedata
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
from glob import glob from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download import progressbar
from util.text import Alphabet
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 15 MAX_SECS = 15
ARCHIVE_DIR_NAME = '{language}' ARCHIVE_DIR_NAME = "{language}"
ARCHIVE_NAME = '{language}.tgz' ARCHIVE_NAME = "{language}.tgz"
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data # Conditionally extract data
@ -48,8 +42,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@ -65,9 +59,13 @@ def one_sample(sample):
wav_filename = sample[0] wav_filename = sample[0]
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1]) label = label_filter(sample[1])
counter = get_counter() counter = get_counter()
rows = [] rows = []
@ -75,27 +73,30 @@ def one_sample(sample):
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
print("conversion failure", wav_filename) print("conversion failure", wav_filename)
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv')) target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
)
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
@ -103,14 +104,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
glob_dir = os.path.join(wav_root_dir, '**/metadata.csv') glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
if any(map(lambda sk: sk in record, SKIP_LIST)): # pylint: disable=cell-var-from-loop if any(
map(lambda sk: sk in record, SKIP_LIST)
): # pylint: disable=cell-var-from-loop
continue continue
with open(record, 'r') as rec: with open(record, "r") as rec:
for re in rec.readlines(): for re in rec.readlines():
re = re.strip().split('|') re = re.strip().split("|")
audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav') audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
transcript = re[2] transcript = re[2]
samples.append((audio, transcript)) samples.append((audio, transcript))
@ -129,9 +132,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -151,39 +154,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
wav_filename=os.path.relpath(wav_filename, extracted_dir), dict(
wav_filesize=os.path.getsize(wav_filename), wav_filename=os.path.relpath(wav_filename, extracted_dir),
transcript=transcript, wav_filesize=os.path.getsize(wav_filename),
)) transcript=transcript,
)
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.') parser = get_importers_parser(
parser.add_argument(dest='target_dir') description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') )
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') parser.add_argument(dest="target_dir")
parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated') parser.add_argument(
parser.add_argument('--language', required=True, type=str, help='Dataset language to use') "--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--skiplist",
type=str,
default="",
help="Directories / books to skip, comma separated",
)
parser.add_argument(
"--language", required=True, type=str, help="Dataset language to use"
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(',')) SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
validate_label = get_validate_label(CLI_ARGS) validate_label = get_validate_label(CLI_ARGS)
def label_filter(label): def label_filter(label):
if CLI_ARGS.normalize: if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:

View File

@ -1,30 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import pandas import os
import tarfile import tarfile
import wave import wave
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
def is_file_truncated(wav_filename, wav_filesize): def is_file_truncated(wav_filename, wav_filesize):
with wave.open(wav_filename, mode='rb') as fin: with wave.open(wav_filename, mode="rb") as fin:
assert fin.getframerate() == 16000 assert fin.getframerate() == 16000
assert fin.getsampwidth() == 2 assert fin.getsampwidth() == 2
assert fin.getnchannels() == 1 assert fin.getnchannels() == 1
@ -37,8 +31,13 @@ def is_file_truncated(wav_filename, wav_filesize):
def preprocess_data(folder_with_archives, target_dir): def preprocess_data(folder_with_archives, target_dir):
# First extract subset archives # First extract subset archives
for subset in ('train', 'dev', 'test'): for subset in ("train", "dev", "test"):
extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir) extract(
os.path.join(
folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
),
target_dir,
)
# Folder structure is now: # Folder structure is now:
# - magicdata_{train,dev,test}.tar.gz # - magicdata_{train,dev,test}.tar.gz
@ -54,58 +53,73 @@ def preprocess_data(folder_with_archives, target_dir):
# name, one containing the speaker ID, and one containing the transcription # name, one containing the speaker ID, and one containing the transcription
def load_set(set_path): def load_set(set_path):
transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0) transcripts = pandas.read_csv(
glob_path = os.path.join(set_path, '*', '*.wav') os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
)
glob_path = os.path.join(set_path, "*", "*.wav")
set_files = [] set_files = []
for wav in glob.glob(glob_path): for wav in glob.glob(glob_path):
try: try:
wav_filename = wav wav_filename = wav
wav_filesize = os.path.getsize(wav) wav_filesize = os.path.getsize(wav)
transcript_key = os.path.basename(wav) transcript_key = os.path.basename(wav)
transcript = transcripts.loc[transcript_key, 'Transcription'] transcript = transcripts.loc[transcript_key, "Transcription"]
# Some files in this dataset are truncated, the header duration # Some files in this dataset are truncated, the header duration
# doesn't match the file size. This causes errors at training # doesn't match the file size. This causes errors at training
# time, so check here if things are fine before including a file # time, so check here if things are fine before including a file
if is_file_truncated(wav_filename, wav_filesize): if is_file_truncated(wav_filename, wav_filesize):
print('Warning: File {} is corrupted, header duration does ' print(
'not match file size. Ignoring.'.format(wav_filename)) "Warning: File {} is corrupted, header duration does "
"not match file size. Ignoring.".format(wav_filename)
)
continue continue
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
except KeyError: except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav)) print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files return set_files
for subset in ('train', 'dev', 'test'): for subset in ("train", "dev", "test"):
print('Loading {} set samples...'.format(subset)) print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(target_dir, subset)) subset_files = load_set(os.path.join(target_dir, subset))
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES) df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s # Trim train set to under 10s
if subset == 'train': if subset == "train":
durations = (df['wav_filesize'] - 44) / 16000 / 2 durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0] df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum())) print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]') with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
df = df[~with_noise] df = df[~with_noise]
print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise))) print(
"Trimming {} samples with noise ([FIL] or [SPK])".format(
sum(with_noise)
)
)
dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset)) dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv)) print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False) df.to_csv(dest_csv, index=False)
def main(): def main():
# https://openslr.org/68/ # https://openslr.org/68/
parser = get_importers_parser(description='Import MAGICDATA corpus') parser = get_importers_parser(description="Import MAGICDATA corpus")
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz') parser.add_argument(
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives') "folder_with_archives",
help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
)
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:
params.target_dir = os.path.join(params.folder_with_archives, 'magicdata') params.target_dir = os.path.join(params.folder_with_archives, "magicdata")
preprocess_data(params.folder_with_archives, params.target_dir) preprocess_data(params.folder_with_archives, params.target_dir)

View File

@ -1,25 +1,19 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob import glob
import json import json
import numpy as np import os
import pandas
import tarfile import tarfile
import numpy as np
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir): def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir)) print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar: with tarfile.open(archive_path) as tar:
tar.extractall(target_dir) tar.extractall(target_dir)
@ -27,7 +21,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir): def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives # First extract main archive and sub-archives
extract(tgz_file, target_dir) extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'primewords_md_2018_set1') main_folder = os.path.join(target_dir, "primewords_md_2018_set1")
# Folder structure is now: # Folder structure is now:
# - primewords_md_2018_set1/ # - primewords_md_2018_set1/
@ -35,14 +29,11 @@ def preprocess_data(tgz_file, target_dir):
# - [0-f]/[00-0f]/*.wav # - [0-f]/[00-0f]/*.wav
# - set1_transcript.json # - set1_transcript.json
transcripts_path = os.path.join(main_folder, 'set1_transcript.json') transcripts_path = os.path.join(main_folder, "set1_transcript.json")
with open(transcripts_path) as fin: with open(transcripts_path) as fin:
transcripts = json.load(fin) transcripts = json.load(fin)
transcripts = { transcripts = {entry["file"]: entry["text"] for entry in transcripts}
entry['file']: entry['text']
for entry in transcripts
}
def load_set(glob_path): def load_set(glob_path):
set_files = [] set_files = []
@ -54,13 +45,13 @@ def preprocess_data(tgz_file, target_dir):
transcript = transcripts[transcript_key] transcript = transcripts[transcript_key]
set_files.append((wav_filename, wav_filesize, transcript)) set_files.append((wav_filename, wav_filesize, transcript))
except KeyError: except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav)) print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files return set_files
# Load all files, then deterministically split into train/dev/test sets # Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav')) all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES) df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True) df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df)) indices = np.arange(0, len(df))
np.random.seed(12345) np.random.seed(12345)
@ -73,29 +64,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000] train_indices = indices[:-10000]
train_files = df.iloc[train_indices] train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2 durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 15.0] train_files = train_files[durations <= 15.0]
print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum())) print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
dest_csv = os.path.join(target_dir, 'primewords_train.csv') dest_csv = os.path.join(target_dir, "primewords_train.csv")
print('Saving train set into {}...'.format(dest_csv)) print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False) train_files.to_csv(dest_csv, index=False)
dev_files = df.iloc[dev_indices] dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'primewords_dev.csv') dest_csv = os.path.join(target_dir, "primewords_dev.csv")
print('Saving dev set into {}...'.format(dest_csv)) print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False) dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices] test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'primewords_test.csv') dest_csv = os.path.join(target_dir, "primewords_test.csv")
print('Saving test set into {}...'.format(dest_csv)) print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False) test_files.to_csv(dest_csv, index=False)
def main(): def main():
# https://www.openslr.org/47/ # https://www.openslr.org/47/
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1') parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz') parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.') parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args() params = parser.parse_args()
if not params.target_dir: if not params.target_dir:

View File

@ -1,45 +1,39 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv import csv
import os
import re import re
import sox
import zipfile
import subprocess import subprocess
import progressbar
import unicodedata
import tarfile import tarfile
import unicodedata
from multiprocessing import Pool import zipfile
from util.downloader import SIMPLE_BAR
from os import path
from glob import glob from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download import progressbar
from util.text import Alphabet import sox
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 15 MAX_SECS = 15
ARCHIVE_DIR_NAME = 'African_Accented_French' ARCHIVE_DIR_NAME = "African_Accented_French"
ARCHIVE_NAME = 'African_Accented_French.tar.gz' ARCHIVE_NAME = "African_Accented_French.tar.gz"
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data # Conditionally extract data
@ -47,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files # Produce CSV files
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME) _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@ -60,81 +55,89 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else: else:
print('Found directory "%s" - not extracting it from archive.' % archive_path) print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0] wav_filename = sample[0]
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1]) label = label_filter(sample[1])
counter = get_counter() counter = get_counter()
rows = [] rows = []
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv')) target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
)
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
wav_root_dir = os.path.join(extracted_dir) wav_root_dir = os.path.join(extracted_dir)
all_files = [ all_files = [
'transcripts/train/yaounde/fn_text.txt', "transcripts/train/yaounde/fn_text.txt",
'transcripts/train/ca16_conv/transcripts.txt', "transcripts/train/ca16_conv/transcripts.txt",
'transcripts/train/ca16_read/conditioned.txt', "transcripts/train/ca16_read/conditioned.txt",
'transcripts/dev/niger_west_african_fr/transcripts.txt', "transcripts/dev/niger_west_african_fr/transcripts.txt",
'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv', "speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
'transcripts/devtest/ca16_read/conditioned.txt', "transcripts/devtest/ca16_read/conditioned.txt",
'transcripts/test/ca16/prompts.txt', "transcripts/test/ca16/prompts.txt",
] ]
transcripts = {} transcripts = {}
for tr in all_files: for tr in all_files:
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source: with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
for line in tr_source.readlines(): for line in tr_source.readlines():
line = line.strip() line = line.strip()
if '.tsv' in tr: if ".tsv" in tr:
sep = ' ' sep = " "
else: else:
sep = ' ' sep = " "
audio = os.path.basename(line.split(sep)[0]) audio = os.path.basename(line.split(sep)[0])
if not ('.wav' in audio): if not (".wav" in audio):
if '.tdf' in audio: if ".tdf" in audio:
audio = audio.replace('.tdf', '.wav') audio = audio.replace(".tdf", ".wav")
else: else:
audio += '.wav' audio += ".wav"
transcript = ' '.join(line.split(sep)[1:]) transcript = " ".join(line.split(sep)[1:])
transcripts[audio] = transcript transcripts[audio] = transcript
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
glob_dir = os.path.join(wav_root_dir, '**/*.wav') glob_dir = os.path.join(wav_root_dir, "**/*.wav")
for record in glob(glob_dir, recursive=True): for record in glob(glob_dir, recursive=True):
record_file = os.path.basename(record) record_file = os.path.basename(record)
if record_file in transcripts: if record_file in transcripts:
@ -156,9 +159,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -178,25 +181,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
wav_filename=wav_filename, dict(
wav_filesize=os.path.getsize(wav_filename), wav_filename=wav_filename,
transcript=transcript, wav_filesize=os.path.getsize(wav_filename),
)) transcript=transcript,
)
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.') parser = get_importers_parser(
parser.add_argument(dest='target_dir') description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet') )
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') parser.add_argument(dest="target_dir")
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -204,9 +220,11 @@ if __name__ == "__main__":
def label_filter(label): def label_filter(label):
if CLI_ARGS.normalize: if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
try: try:

View File

@ -1,44 +1,38 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g. # ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
# ./data/swb/swb1_LDC97S62.tgz # ./data/swb/swb1_LDC97S62.tgz
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/ # from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
import codecs
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import fnmatch import fnmatch
import pandas import os
import subprocess import subprocess
import sys
import tarfile
import unicodedata import unicodedata
import wave import wave
import codecs
import tarfile
import requests
from util.importers import validate_label_eng as validate_label
import librosa import librosa
import soundfile # <= Has an external dependency on libsndfile import pandas
import requests
import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import validate_label_eng as validate_label
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03 # ARCHIVE_NAME refers to ISIP alignments from 01/29/03
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz' ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
ARCHIVE_URL = 'http://www.openslr.org/resources/5/' ARCHIVE_URL = "http://www.openslr.org/resources/5/"
ARCHIVE_DIR_NAME = 'LDC97S62' ARCHIVE_DIR_NAME = "LDC97S62"
LDC_DATASET = 'swb1_LDC97S62.tgz' LDC_DATASET = "swb1_LDC97S62.tgz"
def download_file(folder, url): def download_file(folder, url):
# https://stackoverflow.com/a/16696317/738515 # https://stackoverflow.com/a/16696317/738515
local_filename = url.split('/')[-1] local_filename = url.split("/")[-1]
full_filename = os.path.join(folder, local_filename) full_filename = os.path.join(folder, local_filename)
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(full_filename, 'wb') as f: with open(full_filename, "wb") as f:
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
f.write(chunk) f.write(chunk)
return full_filename return full_filename
@ -46,7 +40,7 @@ def download_file(folder, url):
def maybe_download(archive_url, target_dir, ldc_dataset): def maybe_download(archive_url, target_dir, ldc_dataset):
# If archive file does not exist, download it... # If archive file does not exist, download it...
archive_path = os.path.join(target_dir, ldc_dataset) archive_path = os.path.join(target_dir, ldc_dataset)
ldc_path = archive_url+ldc_dataset ldc_path = archive_url + ldc_dataset
if not os.path.exists(target_dir): if not os.path.exists(target_dir):
print('No path "%s" - creating ...' % target_dir) print('No path "%s" - creating ...' % target_dir)
makedirs(target_dir) makedirs(target_dir)
@ -65,7 +59,7 @@ def _download_and_preprocess_data(data_dir):
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET)) archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
# Check swb1_LDC97S62.tgz then extract # Check swb1_LDC97S62.tgz then extract
assert(os.path.isfile(archive_path)) assert os.path.isfile(archive_path)
_extract(target_dir, archive_path) _extract(target_dir, archive_path)
# Transcripts # Transcripts
@ -73,8 +67,14 @@ def _download_and_preprocess_data(data_dir):
_extract(target_dir, transcripts_path) _extract(target_dir, transcripts_path)
# Check swb1_d1/2/3/4/swb_ms98_transcriptions # Check swb1_d1/2/3/4/swb_ms98_transcriptions
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"] expected_folders = [
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders])) "swb1_d1",
"swb1_d2",
"swb1_d3",
"swb1_d4",
"swb_ms98_transcriptions",
]
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
# Conditionally convert swb sph data to wav # Conditionally convert swb sph data to wav
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav") _maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
@ -83,10 +83,18 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav") _maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
# Conditionally split wav data # Conditionally split wav data
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav") d1 = _maybe_split_wav_and_sentences(
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav") target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav") )
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav") d2 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
)
d3 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
)
d4 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
)
swb_files = d1.append(d2).append(d3).append(d4) swb_files = d1.append(d2).append(d3).append(d4)
@ -118,14 +126,35 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
# Loop over sph files in source_dir and convert each to 16-bit PCM wav # Loop over sph files in source_dir and convert each to 16-bit PCM wav
for root, dirnames, filenames in os.walk(source_dir): for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, "*.sph"): for filename in fnmatch.filter(filenames, "*.sph"):
for channel in ['1', '2']: for channel in ["1", "2"]:
sph_file = os.path.join(root, filename) sph_file = os.path.join(root, filename)
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav" wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename) wav_file = os.path.join(target_dir, wav_filename)
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav" temp_wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ "-temp.wav"
)
temp_wav_file = os.path.join(target_dir, temp_wav_filename) temp_wav_file = os.path.join(target_dir, temp_wav_filename)
print("converting {} to {}".format(sph_file, temp_wav_file)) print("converting {} to {}".format(sph_file, temp_wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file]) subprocess.check_call(
[
"sph2pipe",
"-c",
channel,
"-p",
"-f",
"rif",
sph_file,
temp_wav_file,
]
)
print("upsampling {} to {}".format(temp_wav_file, wav_file)) print("upsampling {} to {}".format(temp_wav_file, wav_file))
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True) audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
soundfile.write(wav_file, audioData, frameRate, "PCM_16") soundfile.write(wav_file, audioData, frameRate, "PCM_16")
@ -136,7 +165,7 @@ def _parse_transcriptions(trans_file):
segments = [] segments = []
with codecs.open(trans_file, "r", "utf-8") as fin: with codecs.open(trans_file, "r", "utf-8") as fin:
for line in fin: for line in fin:
if line.startswith("#") or len(line) <= 1: if line.startswith("#") or len(line) <= 1:
continue continue
tokens = line.split() tokens = line.split()
@ -150,15 +179,19 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode # We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array # returns a bytes() object on Python 3, and text_to_char_array
# expects a string. # expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript)
.decode("ascii", "ignore") .encode("ascii", "ignore")
.decode("ascii", "ignore")
)
segments.append({ segments.append(
"start_time": start_time, {
"stop_time": stop_time, "start_time": start_time,
"transcript": transcript, "stop_time": stop_time,
}) "transcript": transcript,
}
)
return segments return segments
@ -183,8 +216,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
segments = _parse_transcriptions(trans_file) segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file # Open wav corresponding to transcription file
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A'] channel = ("2", "1")[
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav" (os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
]
wav_filename = (
"sw0"
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(source_dir, wav_filename) wav_file = os.path.join(source_dir, wav_filename)
print("splitting {} according to {}".format(wav_file, trans_file)) print("splitting {} according to {}".format(wav_file, trans_file))
@ -200,26 +241,39 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
# Create wav segment filename # Create wav segment filename
start_time = segment["start_time"] start_time = segment["start_time"]
stop_time = segment["stop_time"] stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str( new_wav_filename = (
start_time) + "-" + str(stop_time) + ".wav" os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
if _is_wav_too_short(new_wav_filename): if _is_wav_too_short(new_wav_filename):
continue continue
new_wav_file = os.path.join(target_dir, new_wav_filename) new_wav_file = os.path.join(target_dir, new_wav_filename)
_split_wav(origAudio, start_time, stop_time, new_wav_file) _split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = os.path.getsize(new_wav_file) new_wav_filesize = os.path.getsize(new_wav_file)
transcript = segment["transcript"] transcript = segment["transcript"]
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript)) files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
# Close origAudio # Close origAudio
origAudio.close() origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _is_wav_too_short(wav_filename): def _is_wav_too_short(wav_filename):
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav'] short_wav_filenames = [
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
]
return wav_filename in short_wav_filenames return wav_filename in short_wav_filenames
@ -248,10 +302,24 @@ def _split_sets(filelist):
test_beg = dev_end test_beg = dev_end
test_end = len(filelist) test_end = len(filelist)
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end]) return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0): def _read_data_set(
filelist,
thread_count,
batch_size,
numcep,
numcontext,
stride=1,
offset=0,
next_index=lambda i: i + 1,
limit=0,
):
# Optionally apply dataset size limit # Optionally apply dataset size limit
if limit > 0: if limit > 0:
filelist = filelist.iloc[:limit] filelist = filelist.iloc[:limit]
@ -259,7 +327,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
filelist = filelist[offset::stride] filelist = filelist[offset::stride]
# Return DataSet # Return DataSet
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index) return DataSet(
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,70 +1,76 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
Use "python3 import_swc.py -h" for help Use "python3 import_swc.py -h" for help
''' """
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import re
import csv
import sox
import wave
import shutil
import random
import tarfile
import argparse import argparse
import progressbar import csv
import os
import random
import re
import shutil
import sys
import tarfile
import unicodedata import unicodedata
import wave
import xml.etree.cElementTree as ET import xml.etree.cElementTree as ET
from os import path
from glob import glob
from collections import Counter from collections import Counter
from glob import glob
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label import progressbar
from util.downloader import maybe_download, SIMPLE_BAR import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar" SWC_ARCHIVE = "SWC_{language}.tar"
LANGUAGES = ['dutch', 'english', 'german'] LANGUAGES = ["dutch", "english", "german"]
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker'] FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
CHANNELS = 1 CHANNELS = 1
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
UNKNOWN = '<unknown>' UNKNOWN = "<unknown>"
AUDIO_PATTERN = 'audio*.ogg' AUDIO_PATTERN = "audio*.ogg"
WAV_NAME = 'audio.wav' WAV_NAME = "audio.wav"
ALIGNED_NAME = 'aligned.swc' ALIGNED_NAME = "aligned.swc"
SUBSTITUTIONS = { SUBSTITUTIONS = {
'german': [ "german": [
(re.compile(r'\$'), 'dollar'), (re.compile(r"\$"), "dollar"),
(re.compile(r''), 'euro'), (re.compile(r""), "euro"),
(re.compile(r'£'), 'pfund'), (re.compile(r"£"), "pfund"),
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '), (
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'), re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'), r"\1zehnhundert \2er ",
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'), ),
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'), (re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
(re.compile(r'punkt null null null punkt null null null'), 'millionen'), (
(re.compile(r'eins punkt null null null'), 'ein tausend'), re.compile(
(re.compile(r'punkt null null null'), 'tausend'), r"eins punkt null null null punkt null null null punkt null null null"
(re.compile(r'punkt null'), None) ),
"eine milliarde",
),
(
re.compile(
r"punkt null null null punkt null null null punkt null null null"
),
"milliarden",
),
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
(re.compile(r"eins punkt null null null"), "ein tausend"),
(re.compile(r"punkt null null null"), "tausend"),
(re.compile(r"punkt null"), None),
] ]
} }
DONT_NORMALIZE = { DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
'german': 'ÄÖÜäöüß'
}
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:')) PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
class Sample: class Sample:
@ -98,11 +104,14 @@ def get_sample_size(population_size):
margin_of_error = 0.01 margin_of_error = 0.01
fraction_picking = 0.50 fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99% z_score = 2.58 # Corresponds to confidence level 99%
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2) numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2
)
sample_size = 0 sample_size = 0
for train_size in range(population_size, 0, -1): for train_size in range(population_size, 0, -1):
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \ denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
(margin_of_error ** 2 * train_size) margin_of_error ** 2 * train_size
)
sample_size = int(numerator / denominator) sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size: if 2 * sample_size + train_size <= population_size:
break break
@ -111,14 +120,16 @@ def get_sample_size(population_size):
def maybe_download_language(language): def maybe_download_language(language):
lang_upper = language[0].upper() + language[1:] lang_upper = language[0].upper() + language[1:]
return maybe_download(SWC_ARCHIVE.format(language=lang_upper), return maybe_download(
CLI_ARGS.base_dir, SWC_ARCHIVE.format(language=lang_upper),
SWC_URL.format(language=lang_upper)) CLI_ARGS.base_dir,
SWC_URL.format(language=lang_upper),
)
def maybe_extract(data_dir, extracted_data, archive): def maybe_extract(data_dir, extracted_data, archive):
extracted = path.join(data_dir, extracted_data) extracted = os.path.join(data_dir, extracted_data)
if path.isdir(extracted): if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted)) print('Found directory "{}" - not extracting.'.format(extracted))
else: else:
print('Extracting "{}"...'.format(archive)) print('Extracting "{}"...'.format(archive))
@ -133,29 +144,29 @@ def maybe_extract(data_dir, extracted_data, archive):
def ignored(node): def ignored(node):
if node is None: if node is None:
return False return False
if node.tag == 'ignored': if node.tag == "ignored":
return True return True
return ignored(node.find('..')) return ignored(node.find(".."))
def read_token(token): def read_token(token):
texts, start, end = [], None, None texts, start, end = [], None, None
notes = token.findall('n') notes = token.findall("n")
if len(notes) > 0: if len(notes) > 0:
for note in notes: for note in notes:
attributes = note.attrib attributes = note.attrib
if start is None and 'start' in attributes: if start is None and "start" in attributes:
start = int(attributes['start']) start = int(attributes["start"])
if 'end' in attributes: if "end" in attributes:
token_end = int(attributes['end']) token_end = int(attributes["end"])
if end is None or token_end > end: if end is None or token_end > end:
end = token_end end = token_end
if 'pronunciation' in attributes: if "pronunciation" in attributes:
t = attributes['pronunciation'] t = attributes["pronunciation"]
texts.append(t) texts.append(t)
elif 'text' in token.attrib: elif "text" in token.attrib:
texts.append(token.attrib['text']) texts.append(token.attrib["text"])
return start, end, ' '.join(texts) return start, end, " ".join(texts)
def in_alphabet(alphabet, c): def in_alphabet(alphabet, c):
@ -163,10 +174,12 @@ def in_alphabet(alphabet, c):
ALPHABETS = {} ALPHABETS = {}
def get_alphabet(language): def get_alphabet(language):
if language in ALPHABETS: if language in ALPHABETS:
return ALPHABETS[language] return ALPHABETS[language]
alphabet_path = getattr(CLI_ARGS, language + '_alphabet') alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
alphabet = Alphabet(alphabet_path) if alphabet_path else None alphabet = Alphabet(alphabet_path) if alphabet_path else None
ALPHABETS[language] = alphabet ALPHABETS[language] = alphabet
return alphabet return alphabet
@ -176,27 +189,35 @@ def label_filter(label, language):
label = label.translate(PRE_FILTER) label = label.translate(PRE_FILTER)
label = validate_label(label) label = validate_label(label)
if label is None: if label is None:
return None, 'validation' return None, "validation"
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else [] substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
for pattern, replacement in substitutions: for pattern, replacement in substitutions:
if replacement is None: if replacement is None:
if pattern.match(label): if pattern.match(label):
return None, 'substitution rule' return None, "substitution rule"
else: else:
label = pattern.sub(replacement, label) label = pattern.sub(replacement, label)
chars = [] chars = []
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else '' dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
alphabet = get_alphabet(language) alphabet = get_alphabet(language)
for c in label: for c in label:
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c): if (
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") CLI_ARGS.normalize
and c not in dont_normalize
and not in_alphabet(alphabet, c)
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if not in_alphabet(alphabet, sc): if not in_alphabet(alphabet, sc):
return None, 'illegal character' return None, "illegal character"
chars.append(sc) chars.append(sc)
label = ''.join(chars) label = "".join(chars)
label = validate_label(label) label = validate_label(label)
return label, 'validation' if label is None else None return label, "validation" if label is None else None
def collect_samples(base_dir, language): def collect_samples(base_dir, language):
@ -207,7 +228,9 @@ def collect_samples(base_dir, language):
samples = [] samples = []
reasons = Counter() reasons = Counter()
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'): def add_sample(
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
):
if p_start is not None and p_end is not None and p_text is not None: if p_start is not None and p_end is not None and p_text is not None:
duration = p_end - p_start duration = p_end - p_start
text, filter_reason = label_filter(p_text, language) text, filter_reason = label_filter(p_text, language)
@ -217,53 +240,67 @@ def collect_samples(base_dir, language):
p_reason = filter_reason p_reason = filter_reason
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN: elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
skip = True skip = True
p_reason = 'unknown speaker' p_reason = "unknown speaker"
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN: elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
skip = True skip = True
p_reason = 'unknown article' p_reason = "unknown article"
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long: elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
skip = True skip = True
p_reason = 'exceeded duration' p_reason = "exceeded duration"
elif int(duration / 30) < len(text): elif int(duration / 30) < len(text):
skip = True skip = True
p_reason = 'too short to decode' p_reason = "too short to decode"
elif duration / len(text) < 10: elif duration / len(text) < 10:
skip = True skip = True
p_reason = 'length duration ratio' p_reason = "length duration ratio"
if skip: if skip:
reasons[p_reason] += 1 reasons[p_reason] += 1
else: else:
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)) samples.append(
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
)
elif p_start is None or p_end is None: elif p_start is None or p_end is None:
reasons['missing timestamps'] += 1 reasons["missing timestamps"] += 1
else: else:
reasons['missing text'] += 1 reasons["missing text"] += 1
print('Collecting samples...') print("Collecting samples...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
for root in bar(roots): for root in bar(roots):
wav_path = path.join(root, WAV_NAME) wav_path = os.path.join(root, WAV_NAME)
aligned = ET.parse(path.join(root, ALIGNED_NAME)) aligned = ET.parse(path.join(root, ALIGNED_NAME))
article = UNKNOWN article = UNKNOWN
speaker = UNKNOWN speaker = UNKNOWN
for prop in aligned.iter('prop'): for prop in aligned.iter("prop"):
attributes = prop.attrib attributes = prop.attrib
if 'key' in attributes and 'value' in attributes: if "key" in attributes and "value" in attributes:
if attributes['key'] == 'DC.identifier': if attributes["key"] == "DC.identifier":
article = attributes['value'] article = attributes["value"]
elif attributes['key'] == 'reader.name': elif attributes["key"] == "reader.name":
speaker = attributes['value'] speaker = attributes["value"]
for sentence in aligned.iter('s'): for sentence in aligned.iter("s"):
if ignored(sentence): if ignored(sentence):
continue continue
split = False split = False
tokens = list(map(read_token, sentence.findall('t'))) tokens = list(map(read_token, sentence.findall("t")))
sample_start, sample_end, token_texts, sample_texts = None, None, [], [] sample_start, sample_end, token_texts, sample_texts = None, None, [], []
for token_start, token_end, token_text in tokens: for token_start, token_end, token_text in tokens:
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text): if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts), add_sample(
p_reason='has numbers') wav_path,
sample_start, sample_end, token_texts, sample_texts = None, None, [], [] article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="has numbers",
)
sample_start, sample_end, token_texts, sample_texts = (
None,
None,
[],
[],
)
continue continue
if sample_start is None: if sample_start is None:
sample_start = token_start sample_start = token_start
@ -271,20 +308,37 @@ def collect_samples(base_dir, language):
continue continue
token_texts.append(token_text) token_texts.append(token_text)
if token_end is not None: if token_end is not None:
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0: if (
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts), token_start != sample_start
p_reason='split') and token_end - sample_start > CLI_ARGS.max_duration > 0
):
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split",
)
sample_start = sample_end sample_start = sample_end
sample_texts = [] sample_texts = []
split = True split = True
sample_end = token_end sample_end = token_end
sample_texts.extend(token_texts) sample_texts.extend(token_texts)
token_texts = [] token_texts = []
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts), add_sample(
p_reason='split' if split else 'complete') wav_path,
print('Skipped samples:') article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split" if split else "complete",
)
print("Skipped samples:")
for reason, n in reasons.most_common(): for reason, n in reasons.most_common():
print(' - {}: {}'.format(reason, n)) print(" - {}: {}".format(reason, n))
return samples return samples
@ -294,8 +348,8 @@ def maybe_convert_one_to_wav(entry):
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS) transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
combiner = sox.Combiner() combiner = sox.Combiner()
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS) combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
output_wav = path.join(root, WAV_NAME) output_wav = os.path.join(root, WAV_NAME)
if path.isfile(output_wav): if os.path.isfile(output_wav):
return return
files = sorted(glob(path.join(root, AUDIO_PATTERN))) files = sorted(glob(path.join(root, AUDIO_PATTERN)))
try: try:
@ -304,18 +358,18 @@ def maybe_convert_one_to_wav(entry):
elif len(files) > 1: elif len(files) > 1:
wav_files = [] wav_files = []
for i, file in enumerate(files): for i, file in enumerate(files):
wav_path = path.join(root, 'audio{}.wav'.format(i)) wav_path = os.path.join(root, "audio{}.wav".format(i))
transformer.build(file, wav_path) transformer.build(file, wav_path)
wav_files.append(wav_path) wav_files.append(wav_path)
combiner.set_input_format(file_type=['wav'] * len(wav_files)) combiner.set_input_format(file_type=["wav"] * len(wav_files))
combiner.build(wav_files, output_wav, 'concatenate') combiner.build(wav_files, output_wav, "concatenate")
except sox.core.SoxError: except sox.core.SoxError:
return return
def maybe_convert_to_wav(base_dir): def maybe_convert_to_wav(base_dir):
roots = list(os.walk(base_dir)) roots = list(os.walk(base_dir))
print('Converting and joining source audio files...') print("Converting and joining source audio files...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
tp = ThreadPool() tp = ThreadPool()
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)): for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
@ -335,53 +389,66 @@ def assign_sub_sets(samples):
sample_set.extend(speakers.pop(0)) sample_set.extend(speakers.pop(0))
train_set = sum(speakers, []) train_set = sum(speakers, [])
if len(train_set) == 0: if len(train_set) == 0:
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data') print(
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
)
random.seed(42) # same source data == same output random.seed(42) # same source data == same output
random.shuffle(samples) random.shuffle(samples)
for index, sample in enumerate(samples): for index, sample in enumerate(samples):
if index < sample_size: if index < sample_size:
sample.sub_set = 'dev' sample.sub_set = "dev"
elif index < 2 * sample_size: elif index < 2 * sample_size:
sample.sub_set = 'test' sample.sub_set = "test"
else: else:
sample.sub_set = 'train' sample.sub_set = "train"
else: else:
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]: for sub_set, sub_set_samples in [
("train", train_set),
("dev", sample_sets[0]),
("test", sample_sets[1]),
]:
for sample in sub_set_samples: for sample in sub_set_samples:
sample.sub_set = sub_set sample.sub_set = sub_set
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items(): for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60) t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
print('Sub-set "{}" with {} samples (duration: {:.2f} h)' print(
.format(sub_set, len(sub_set_samples), t)) 'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
sub_set, len(sub_set_samples), t
)
)
def create_sample_dirs(language): def create_sample_dirs(language):
print('Creating sample directories...') print("Creating sample directories...")
for set_name in ['train', 'dev', 'test']: for set_name in ["train", "dev", "test"]:
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name) dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
if not path.isdir(dir_path): if not os.path.isdir(dir_path):
os.mkdir(dir_path) os.mkdir(dir_path)
def split_audio_files(samples, language): def split_audio_files(samples, language):
print('Splitting audio files...') print("Splitting audio files...")
sub_sets = Counter() sub_sets = Counter()
src_wav_files = group(samples, lambda s: s.wav_path).items() src_wav_files = group(samples, lambda s: s.wav_path).items()
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
for wav_path, file_samples in bar(src_wav_files): for wav_path, file_samples in bar(src_wav_files):
file_samples = sorted(file_samples, key=lambda s: s.start) file_samples = sorted(file_samples, key=lambda s: s.start)
with wave.open(wav_path, 'r') as src_wav_file: with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate() rate = src_wav_file.getframerate()
for sample in file_samples: for sample in file_samples:
index = sub_sets[sample.sub_set] index = sub_sets[sample.sub_set]
sample_wav_path = path.join(CLI_ARGS.base_dir, sample_wav_path = os.path.join(
language + '-' + sample.sub_set, CLI_ARGS.base_dir,
'sample-{0:06d}.wav'.format(index)) language + "-" + sample.sub_set,
"sample-{0:06d}.wav".format(index),
)
sample.wav_path = sample_wav_path sample.wav_path = sample_wav_path
sub_sets[sample.sub_set] += 1 sub_sets[sample.sub_set] += 1
src_wav_file.setpos(int(sample.start * rate / 1000.0)) src_wav_file.setpos(int(sample.start * rate / 1000.0))
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0)) data = src_wav_file.readframes(
with wave.open(sample_wav_path, 'w') as sample_wav_file: int((sample.end - sample.start) * rate / 1000.0)
)
with wave.open(sample_wav_path, "w") as sample_wav_file:
sample_wav_file.setnchannels(src_wav_file.getnchannels()) sample_wav_file.setnchannels(src_wav_file.getnchannels())
sample_wav_file.setsampwidth(src_wav_file.getsampwidth()) sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
sample_wav_file.setframerate(rate) sample_wav_file.setframerate(rate)
@ -391,22 +458,26 @@ def split_audio_files(samples, language):
def write_csvs(samples, language): def write_csvs(samples, language):
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items(): for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
set_samples = sorted(set_samples, key=lambda s: s.wav_path) set_samples = sorted(set_samples, key=lambda s: s.wav_path)
base_dir = path.abspath(CLI_ARGS.base_dir) base_dir = os.path.abspath(CLI_ARGS.base_dir)
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv') csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
print('Writing "{}"...'.format(csv_path)) print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file: with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES) writer = csv.DictWriter(
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
)
writer.writeheader() writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(
max_value=len(set_samples), widgets=SIMPLE_BAR
)
for sample in bar(set_samples): for sample in bar(set_samples):
row = { row = {
'wav_filename': path.relpath(sample.wav_path, base_dir), "wav_filename": os.path.relpath(sample.wav_path, base_dir),
'wav_filesize': path.getsize(sample.wav_path), "wav_filesize": os.path.getsize(sample.wav_path),
'transcript': sample.text "transcript": sample.text,
} }
if CLI_ARGS.add_meta: if CLI_ARGS.add_meta:
row['article'] = sample.article row["article"] = sample.article
row['speaker'] = sample.speaker row["speaker"] = sample.speaker
writer.writerow(row) writer.writerow(row)
@ -414,8 +485,8 @@ def cleanup(archive, language):
if not CLI_ARGS.keep_archive: if not CLI_ARGS.keep_archive:
print('Removing archive "{}"...'.format(archive)) print('Removing archive "{}"...'.format(archive))
os.remove(archive) os.remove(archive)
language_dir = path.join(CLI_ARGS.base_dir, language) language_dir = os.path.join(CLI_ARGS.base_dir, language)
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir): if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
print('Removing intermediate files in "{}"...'.format(language_dir)) print('Removing intermediate files in "{}"...'.format(language_dir))
shutil.rmtree(language_dir) shutil.rmtree(language_dir)
@ -433,34 +504,75 @@ def prepare_language(language):
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora') parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
parser.add_argument('base_dir', help='Directory containing all data') parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES))) parser.add_argument(
parser.add_argument('--exclude_numbers', type=bool, default=True, "--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
help='If sequences with non-transliterated numbers should be excluded') )
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds') parser.add_argument(
parser.add_argument('--ignore_too_long', type=bool, default=False, "--exclude_numbers",
help='If samples exceeding max_duration should be removed') type=bool,
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') default=True,
help="If sequences with non-transliterated numbers should be excluded",
)
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--ignore_too_long",
type=bool,
default=False,
help="If samples exceeding max_duration should be removed",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
for language in LANGUAGES: for language in LANGUAGES:
parser.add_argument('--{}_alphabet'.format(language), parser.add_argument(
help='Exclude {} samples with characters not in provided alphabet file'.format(language)) "--{}_alphabet".format(language),
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns') help="Exclude {} samples with characters not in provided alphabet file".format(
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers') language
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles') ),
parser.add_argument('--keep_archive', type=bool, default=True, )
help='If downloaded archives should be kept') parser.add_argument(
parser.add_argument('--keep_intermediate', type=bool, default=False, "--add_meta", action="store_true", help="Adds article and speaker CSV columns"
help='If intermediate files should be kept') )
parser.add_argument(
"--exclude_unknown_speakers",
action="store_true",
help="Exclude unknown speakers",
)
parser.add_argument(
"--exclude_unknown_articles",
action="store_true",
help="Exclude unknown articles",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
parser.add_argument(
"--keep_intermediate",
type=bool,
default=False,
help="If intermediate files should be kept",
)
return parser.parse_args() return parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
if CLI_ARGS.language == 'all': if CLI_ARGS.language == "all":
for lang in LANGUAGES: for lang in LANGUAGES:
prepare_language(lang) prepare_language(lang)
elif CLI_ARGS.language in LANGUAGES: elif CLI_ARGS.language in LANGUAGES:
prepare_language(CLI_ARGS.language) prepare_language(CLI_ARGS.language)
else: else:
fail('Wrong language id') fail("Wrong language id")

View File

@ -1,24 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import pandas
import tarfile import tarfile
import unicodedata import unicodedata
import wave import wave
from glob import glob from glob import glob
from os import makedirs, path, remove, rmdir from os import makedirs, path, remove, rmdir
import pandas
from sox import Transformer from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from util.stm import parse_stm_file
from deepspeech_training.util.downloader import maybe_download
from deepspeech_training.util.stm import parse_stm_file
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data # Conditionally download data
@ -41,6 +35,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False) dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False) test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive): def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir # If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(path.join(data_dir, extracted_data)): if not gfile.Exists(path.join(data_dir, extracted_data)):
@ -48,6 +43,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir) tar.extractall(data_dir)
tar.close() tar.close()
def _maybe_convert_wav(data_dir, extracted_data): def _maybe_convert_wav(data_dir, extracted_data):
# Create extracted_data dir # Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data) extracted_dir = path.join(data_dir, extracted_data)
@ -61,6 +57,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
# Conditionally convert test sph to wav # Conditionally convert test sph to wav
_maybe_convert_wav_dataset(extracted_dir, "test") _maybe_convert_wav_dataset(extracted_dir, "test")
def _maybe_convert_wav_dataset(extracted_dir, data_set): def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Create source dir # Create source dir
source_dir = path.join(extracted_dir, data_set, "sph") source_dir = path.join(extracted_dir, data_set, "sph")
@ -84,6 +81,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Remove source_dir # Remove source_dir
rmdir(source_dir) rmdir(source_dir)
def _maybe_split_sentences(data_dir, extracted_data): def _maybe_split_sentences(data_dir, extracted_data):
# Create extracted_data dir # Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data) extracted_dir = path.join(data_dir, extracted_data)
@ -99,6 +97,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
return train_files, dev_files, test_files return train_files, dev_files, test_files
def _maybe_split_dataset(extracted_dir, data_set): def _maybe_split_dataset(extracted_dir, data_set):
# Create stm dir # Create stm dir
stm_dir = path.join(extracted_dir, data_set, "stm") stm_dir = path.join(extracted_dir, data_set, "stm")
@ -116,14 +115,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
# Open wav corresponding to stm_file # Open wav corresponding to stm_file
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav" wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
wav_file = path.join(wav_dir, wav_filename) wav_file = path.join(wav_dir, wav_filename)
origAudio = wave.open(wav_file,'r') origAudio = wave.open(wav_file, "r")
# Loop over stm_segments and split wav_file for each segment # Loop over stm_segments and split wav_file for each segment
for stm_segment in stm_segments: for stm_segment in stm_segments:
# Create wav segment filename # Create wav segment filename
start_time = stm_segment.start_time start_time = stm_segment.start_time
stop_time = stm_segment.stop_time stop_time = stm_segment.stop_time
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav" new_wav_filename = (
path.splitext(path.basename(stm_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = path.join(wav_dir, new_wav_filename) new_wav_file = path.join(wav_dir, new_wav_filename)
# If the wav segment filename does not exist create it # If the wav segment filename does not exist create it
@ -131,23 +137,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
_split_wav(origAudio, start_time, stop_time, new_wav_file) _split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = path.getsize(new_wav_file) new_wav_filesize = path.getsize(new_wav_file)
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)) files.append(
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
)
# Close origAudio # Close origAudio
origAudio.close() origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _split_wav(origAudio, start_time, stop_time, new_wav_file): def _split_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio.getframerate() frameRate = origAudio.getframerate()
origAudio.setpos(int(start_time*frameRate)) origAudio.setpos(int(start_time * frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate)) chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
chunkAudio = wave.open(new_wav_file,'w') chunkAudio = wave.open(new_wav_file, "w")
chunkAudio.setnchannels(origAudio.getnchannels()) chunkAudio.setnchannels(origAudio.getnchannels())
chunkAudio.setsampwidth(origAudio.getsampwidth()) chunkAudio.setsampwidth(origAudio.getsampwidth())
chunkAudio.setframerate(frameRate) chunkAudio.setframerate(frameRate)
chunkAudio.writeframes(chunkData) chunkAudio.writeframes(chunkData)
chunkAudio.close() chunkAudio.close()
if __name__ == "__main__": if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
NAME : LDC TIMIT Dataset NAME : LDC TIMIT Dataset
URL : https://catalog.ldc.upenn.edu/ldc93s1 URL : https://catalog.ldc.upenn.edu/ldc93s1
HOURS : 5 HOURS : 5
@ -8,29 +8,32 @@
AUTHORS : Garofolo, John, et al. AUTHORS : Garofolo, John, et al.
TYPE : LDC Membership TYPE : LDC Membership
LICENCE : LDC User Agreement LICENCE : LDC User Agreement
''' """
import errno import errno
import fnmatch
import os import os
from os import path import subprocess
import sys import sys
import tarfile import tarfile
import fnmatch from os import path
import pandas as pd import pandas as pd
import subprocess
def clean(word): def clean(word):
# LC ALL & strip punctuation which are not required # LC ALL & strip punctuation which are not required
new = word.lower().replace('.', '') new = word.lower().replace(".", "")
new = new.replace(',', '') new = new.replace(",", "")
new = new.replace(';', '') new = new.replace(";", "")
new = new.replace('"', '') new = new.replace('"', "")
new = new.replace('!', '') new = new.replace("!", "")
new = new.replace('?', '') new = new.replace("?", "")
new = new.replace(':', '') new = new.replace(":", "")
new = new.replace('-', '') new = new.replace("-", "")
return new return new
def _preprocess_data(args): def _preprocess_data(args):
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1 # Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
@ -40,16 +43,24 @@ def _preprocess_data(args):
if ignoreSASentences: if ignoreSASentences:
print("Using recommended ignore SA sentences") print("Using recommended ignore SA sentences")
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)") print(
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
)
else: else:
print("Using unrecommended setting to include SA sentences") print("Using unrecommended setting to include SA sentences")
datapath = args datapath = args
target = path.join(datapath, "TIMIT") target = path.join(datapath, "TIMIT")
print("Checking to see if data has already been extracted in given argument: %s", target) print(
"Checking to see if data has already been extracted in given argument: %s",
target,
)
if not path.isdir(target): if not path.isdir(target):
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath) print(
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
datapath,
)
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz") filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
if path.isfile(filepath): if path.isfile(filepath):
print("File found, extracting") print("File found, extracting")
@ -103,40 +114,58 @@ def _preprocess_data(args):
# if ignoreSAsentences we only want those without SA in the name # if ignoreSAsentences we only want those without SA in the name
# OR # OR
# if not ignoreSAsentences we want all to be added # if not ignoreSAsentences we want all to be added
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences): if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
if 'train' in full_wav.lower(): not ignoreSASentences
):
if "train" in full_wav.lower():
train_list_wavs.append(full_wav) train_list_wavs.append(full_wav)
train_list_trans.append(trans) train_list_trans.append(trans)
train_list_size.append(wav_filesize) train_list_size.append(wav_filesize)
elif 'test' in full_wav.lower(): elif "test" in full_wav.lower():
test_list_wavs.append(full_wav) test_list_wavs.append(full_wav)
test_list_trans.append(trans) test_list_trans.append(trans)
test_list_size.append(wav_filesize) test_list_size.append(wav_filesize)
else: else:
raise IOError raise IOError
a = {'wav_filename': train_list_wavs, a = {
'wav_filesize': train_list_size, "wav_filename": train_list_wavs,
'transcript': train_list_trans "wav_filesize": train_list_size,
} "transcript": train_list_trans,
}
c = {'wav_filename': test_list_wavs, c = {
'wav_filesize': test_list_size, "wav_filename": test_list_wavs,
'transcript': test_list_trans "wav_filesize": test_list_size,
} "transcript": test_list_trans,
}
all = {'wav_filename': train_list_wavs + test_list_wavs, all = {
'wav_filesize': train_list_size + test_list_size, "wav_filename": train_list_wavs + test_list_wavs,
'transcript': train_list_trans + test_list_trans "wav_filesize": train_list_size + test_list_size,
} "transcript": train_list_trans + test_list_trans,
}
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int) df_all = pd.DataFrame(
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int) all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int) )
df_train = pd.DataFrame(
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_test = pd.DataFrame(
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_all.to_csv(
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_train.to_csv(
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_test.to_csv(
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
if __name__ == "__main__": if __name__ == "__main__":
_preprocess_data(sys.argv[1]) _preprocess_data(sys.argv[1])

View File

@ -1,52 +1,53 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import absolute_import, division, print_function import csv
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os import os
import re import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import unidecode
import zipfile
import sox
import subprocess import subprocess
import progressbar import zipfile
from multiprocessing import Pool from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path import progressbar
import sox
from util.downloader import maybe_download import unidecode
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 15 MAX_SECS = 15
ARCHIVE_NAME = '2019-04-11_fr_FR' ARCHIVE_NAME = "2019-04-11_fr_FR"
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip' ARCHIVE_URL = (
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
)
def _download_and_preprocess_data(target_dir, english_compatible=False): def _download_and_preprocess_data(target_dir, english_compatible=False):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL) archive_path = maybe_download(
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
)
# Conditionally extract archive data # Conditionally extract archive data
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path) _maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav # Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible) _maybe_convert_sets(
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
)
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path) print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path): if not os.path.isdir(extracted_path):
os.mkdir(extracted_path) os.mkdir(extracted_path)
@ -58,16 +59,20 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def one_sample(sample): def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """ """ Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = sample['path'] orig_filename = sample["path"]
# Storing wav files next to the wav ones - just with a different suffix # Storing wav files next to the wav ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav" wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename) _maybe_convert_wav(orig_filename, wav_filename)
file_size = -1 file_size = -1
frames = 0 frames = 0
if path.exists(wav_filename): if os.path.exists(wav_filename):
file_size = path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT)) frames = int(
label = sample['text'] subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = sample["text"]
rows = [] rows = []
@ -75,40 +80,41 @@ def one_sample(sample):
counter = get_counter() counter = get_counter()
if file_size == -1: if file_size == -1:
# Excluding samples that failed upon conversion # Excluding samples that failed upon conversion
counter['failed'] += 1 counter["failed"] += 1
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
counter['invalid_label'] += 1 counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
counter['too_short'] += 1 counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1 counter["too_long"] += 1
else: else:
# This one is good - keep it for the target CSV # This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label)) rows.append((wav_filename, file_size, label))
counter['all'] += 1 counter["all"] += 1
counter['total_time'] += frames counter["total_time"] += frames
return (counter, rows) return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data) extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one # override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv') target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
if os.path.isfile(target_csv_template): if os.path.isfile(target_csv_template):
return return
path_to_original_csv = os.path.join(extracted_dir, 'data.csv') path_to_original_csv = os.path.join(extracted_dir, "data.csv")
with open(path_to_original_csv) as csv_f: with open(path_to_original_csv) as csv_f:
data = [ data = [
d for d in csv.DictReader(csv_f, delimiter=',') d
if float(d['duration']) <= MAX_SECS for d in csv.DictReader(csv_f, delimiter=",")
if float(d["duration"]) <= MAX_SECS
] ]
for line in data: for line in data:
line['path'] = os.path.join(extracted_dir, line['path']) line["path"] = os.path.join(extracted_dir, line["path"])
num_samples = len(data) num_samples = len(data)
rows = [] rows = []
@ -125,9 +131,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10% with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10% with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -136,7 +142,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader() test_writer.writeheader()
for i, item in enumerate(rows): for i, item in enumerate(rows):
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible)) transcript = validate_label(
cleanup_transcript(
item[2], english_compatible=english_compatible
)
)
if not transcript: if not transcript:
continue continue
wav_filename = os.path.join(target_dir, extracted_data, item[0]) wav_filename = os.path.join(target_dir, extracted_data, item[0])
@ -147,45 +157,53 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow(dict( writer.writerow(
wav_filename=wav_filename, dict(
wav_filesize=os.path.getsize(wav_filename), wav_filename=wav_filename,
transcript=transcript, wav_filesize=os.path.getsize(wav_filename),
)) transcript=transcript,
)
)
imported_samples = get_imported_samples(counter) imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples assert counter["all"] == num_samples
assert len(rows) == imported_samples assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS) print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename): def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE) transformer.convert(samplerate=SAMPLE_RATE)
try: try:
transformer.build(orig_filename, wav_filename) transformer.build(orig_filename, wav_filename)
except sox.core.SoxError as ex: except sox.core.SoxError as ex:
print('SoX processing error', ex, orig_filename, wav_filename) print("SoX processing error", ex, orig_filename, wav_filename)
PUNCTUATIONS_REG = re.compile(r"\-,;!?.()\[\]*…—]") PUNCTUATIONS_REG = re.compile(r"\-,;!?.()\[\]*…—]")
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}') MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
def cleanup_transcript(text, english_compatible=False): def cleanup_transcript(text, english_compatible=False):
text = text.replace('', "'").replace('\u00A0', ' ') text = text.replace("", "'").replace("\u00A0", " ")
text = PUNCTUATIONS_REG.sub(' ', text) text = PUNCTUATIONS_REG.sub(" ", text)
text = MULTIPLE_SPACES_REG.sub(' ', text) text = MULTIPLE_SPACES_REG.sub(" ", text)
if english_compatible: if english_compatible:
text = unidecode.unidecode(text) text = unidecode.unidecode(text)
return text.strip().lower() return text.strip().lower()
def handle_args(): def handle_args():
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.') parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
parser.add_argument(dest='target_dir') parser.add_argument(dest="target_dir")
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.') parser.add_argument(
"--english-compatible",
action="store_true",
dest="english_compatible",
help="Remove diactrics and other non-ascii chars.",
)
return parser.parse_args() return parser.parse_args()

View File

@ -1,45 +1,40 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
Use "python3 import_tuda.py -h" for help Use "python3 import_tuda.py -h" for help
''' """
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import wave
import tarfile
import argparse import argparse
import progressbar import csv
import os
import tarfile
import unicodedata import unicodedata
import wave
import xml.etree.cElementTree as ET import xml.etree.cElementTree as ET
from os import path
from collections import Counter from collections import Counter
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
TUDA_VERSION = 'v2' import progressbar
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE) from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE) from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet
TUDA_VERSION = "v2"
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
TUDA_PACKAGE
)
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
CHANNELS = 1 CHANNELS = 1
SAMPLE_WIDTH = 2 SAMPLE_WIDTH = 2
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
def maybe_extract(archive): def maybe_extract(archive):
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE) extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if path.isdir(extracted): if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted)) print('Found directory "{}" - not extracting.'.format(extracted))
else: else:
print('Extracting "{}"...'.format(archive)) print('Extracting "{}"...'.format(archive))
@ -52,86 +47,100 @@ def maybe_extract(archive):
def check_and_prepare_sentence(sentence): def check_and_prepare_sentence(sentence):
sentence = sentence.lower().replace('co2', 'c o zwei') sentence = sentence.lower().replace("co2", "c o zwei")
chars = [] chars = []
for c in sentence: for c in sentence:
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)): if (
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") CLI_ARGS.normalize
and c not in "äöüß"
and (ALPHABET is None or not ALPHABET.has_char(c))
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if ALPHABET is not None and not ALPHABET.has_char(c): if ALPHABET is not None and not ALPHABET.has_char(c):
return None return None
chars.append(sc) chars.append(sc)
return validate_label(''.join(chars)) return validate_label("".join(chars))
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
try: try:
with wave.open(wav_path, 'r') as src_wav_file: with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate() rate = src_wav_file.getframerate()
channels = src_wav_file.getnchannels() channels = src_wav_file.getnchannels()
sample_width = src_wav_file.getsampwidth() sample_width = src_wav_file.getsampwidth()
milliseconds = int(src_wav_file.getnframes() * 1000 / rate) milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
if rate != SAMPLE_RATE: if rate != SAMPLE_RATE:
return False, 'wrong sample rate' return False, "wrong sample rate"
if channels != CHANNELS: if channels != CHANNELS:
return False, 'wrong number of channels' return False, "wrong number of channels"
if sample_width != SAMPLE_WIDTH: if sample_width != SAMPLE_WIDTH:
return False, 'wrong sample width' return False, "wrong sample width"
if milliseconds / len(sentence) < 30: if milliseconds / len(sentence) < 30:
return False, 'too short' return False, "too short"
if milliseconds > CLI_ARGS.max_duration > 0: if milliseconds > CLI_ARGS.max_duration > 0:
return False, 'too long' return False, "too long"
except wave.Error: except wave.Error:
return False, 'invalid wav file' return False, "invalid wav file"
except EOFError: except EOFError:
return False, 'premature EOF' return False, "premature EOF"
return True, 'OK' return True, "OK"
def write_csvs(extracted): def write_csvs(extracted):
sample_counter = 0 sample_counter = 0
reasons = Counter() reasons = Counter()
for sub_set in ['train', 'dev', 'test']: for sub_set in ["train", "dev", "test"]:
set_path = path.join(extracted, sub_set) set_path = os.path.join(extracted, sub_set)
set_files = os.listdir(set_path) set_files = os.listdir(set_path)
recordings = {} recordings = {}
for file in set_files: for file in set_files:
if file.endswith('.xml'): if file.endswith(".xml"):
recordings[file[:-4]] = [] recordings[file[:-4]] = []
for file in set_files: for file in set_files:
if file.endswith('.wav') and '_' in file: if file.endswith(".wav") and "_" in file:
prefix = file.split('_')[0] prefix = file.split("_")[0]
if prefix in recordings: if prefix in recordings:
recordings[prefix].append(file) recordings[prefix].append(file)
recordings = recordings.items() recordings = recordings.items()
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set)) csv_path = os.path.join(
CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
)
print('Writing "{}"...'.format(csv_path)) print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file: with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES) writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
writer.writeheader() writer.writeheader()
set_dir = path.join(extracted, sub_set) set_dir = os.path.join(extracted, sub_set)
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
for prefix, wav_names in bar(recordings): for prefix, wav_names in bar(recordings):
xml_path = path.join(set_dir, prefix + '.xml') xml_path = os.path.join(set_dir, prefix + ".xml")
meta = ET.parse(xml_path).getroot() meta = ET.parse(xml_path).getroot()
sentence = list(meta.iter('cleaned_sentence'))[0].text sentence = list(meta.iter("cleaned_sentence"))[0].text
sentence = check_and_prepare_sentence(sentence) sentence = check_and_prepare_sentence(sentence)
if sentence is None: if sentence is None:
continue continue
for wav_name in wav_names: for wav_name in wav_names:
sample_counter += 1 sample_counter += 1
wav_path = path.join(set_path, wav_name) wav_path = os.path.join(set_path, wav_name)
keep, reason = check_wav_file(wav_path, sentence) keep, reason = check_wav_file(wav_path, sentence)
if keep: if keep:
writer.writerow({ writer.writerow(
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir), {
'wav_filesize': path.getsize(wav_path), "wav_filename": os.path.relpath(
'transcript': sentence.lower() wav_path, CLI_ARGS.base_dir
}) ),
"wav_filesize": os.path.getsize(wav_path),
"transcript": sentence.lower(),
}
)
else: else:
reasons[reason] += 1 reasons[reason] += 1
if len(reasons.keys()) > 0: if len(reasons.keys()) > 0:
print('Excluded samples:') print("Excluded samples:")
for reason, n in reasons.most_common(): for reason, n in reasons.most_common():
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter)) print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
@ -150,13 +159,29 @@ def download_and_prepare():
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)') parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
parser.add_argument('base_dir', help='Directory containing all data') parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds') parser.add_argument(
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones') "--max_duration",
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file') type=int,
parser.add_argument('--keep_archive', type=bool, default=True, default=10000,
help='If downloaded archives should be kept') help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--alphabet",
help="Exclude samples with characters not in provided alphabet file",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
return parser.parse_args() return parser.parse_args()

View File

@ -1,29 +1,22 @@
#!/usr/bin/env python #!/usr/bin/env python
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf # VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0. # Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html # as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os import os
import random import random
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re import re
from multiprocessing import Pool
from zipfile import ZipFile
import librosa import librosa
import progressbar import progressbar
from os import path from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from multiprocessing import Pool from deepspeech_training.util.importers import (
from util.downloader import maybe_download, SIMPLE_BAR get_counter,
from zipfile import ZipFile get_imported_samples,
print_import_report,
)
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
MAX_SECS = 10 MAX_SECS = 10
@ -37,7 +30,7 @@ ARCHIVE_URL = (
def _download_and_preprocess_data(target_dir): def _download_and_preprocess_data(target_dir):
# Making path absolute # Making path absolute
target_dir = path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
# Conditionally download data # Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL) archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data # Conditionally extract common voice data
@ -48,8 +41,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path): def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data) extracted_path = os.path.join(target_dir, extracted_data)
if not path.exists(extracted_path): if not os.path.exists(extracted_path):
print(f"No directory {extracted_path} - extracting archive...") print(f"No directory {extracted_path} - extracting archive...")
with ZipFile(archive_path, "r") as zipobj: with ZipFile(archive_path, "r") as zipobj:
# Extract all the contents of zip file in current directory # Extract all the contents of zip file in current directory
@ -59,15 +52,17 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48") extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt") txt_dir = os.path.join(target_dir, extracted_data, "txt")
directory = os.path.expanduser(extracted_dir) directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory))) srtd = len(sorted(os.listdir(directory)))
all_samples = [] all_samples = []
for target in sorted(os.listdir(directory)): for target in sorted(os.listdir(directory)):
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1])) all_samples += _maybe_prepare_set(
path.join(extracted_dir, os.path.split(target)[-1])
)
num_samples = len(all_samples) num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...") print(f"Converting wav files to {SAMPLE_RATE}hz...")
@ -81,6 +76,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
_write_csv(extracted_dir, txt_dir, target_dir) _write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample): def one_sample(sample):
if is_audio_file(sample): if is_audio_file(sample):
y, sr = librosa.load(sample, sr=16000) y, sr = librosa.load(sample, sr=16000)
@ -103,6 +99,7 @@ def _maybe_prepare_set(target_csv):
samples = new_samples samples = new_samples
return samples return samples
def _write_csv(extracted_dir, txt_dir, target_dir): def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file") print(f"Writing CSV file")
dset_abs_path = extracted_dir dset_abs_path = extracted_dir
@ -197,7 +194,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filepath): def is_audio_file(filepath):
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS) return any(
os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,24 +1,19 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import codecs import codecs
import sys
import os import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import tarfile
import pandas
import re import re
import unicodedata import tarfile
import threading import threading
from multiprocessing.pool import ThreadPool import unicodedata
import urllib
from six.moves import urllib
from glob import glob from glob import glob
from multiprocessing.pool import ThreadPool
from os import makedirs, path from os import makedirs, path
import pandas
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from util.downloader import maybe_download from deepspeech_training.util.downloader import maybe_download
"""The number of jobs to run in parallel""" """The number of jobs to run in parallel"""
NUM_PARALLEL = 8 NUM_PARALLEL = 8
@ -26,8 +21,10 @@ NUM_PARALLEL = 8
"""Lambda function returns the filename of a path""" """Lambda function returns the filename of a path"""
filename_of = lambda x: path.split(x)[1] filename_of = lambda x: path.split(x)[1]
class AtomicCounter(object): class AtomicCounter(object):
"""A class that atomically increments a counter""" """A class that atomically increments a counter"""
def __init__(self, start_count=0): def __init__(self, start_count=0):
"""Initialize the counter """Initialize the counter
:param start_count: the number to start counting at :param start_count: the number to start counting at
@ -50,6 +47,7 @@ class AtomicCounter(object):
"""Returns the current value of the counter (not atomic)""" """Returns the current value of the counter (not atomic)"""
return self.__count return self.__count
def _parallel_downloader(voxforge_url, archive_dir, total, counter): def _parallel_downloader(voxforge_url, archive_dir, total, counter):
"""Generate a function to download a file based on given parameters """Generate a function to download a file based on given parameters
This works by currying the above given arguments into a closure This works by currying the above given arguments into a closure
@ -61,6 +59,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
:param counter: an atomic counter to keep track of # of downloaded files :param counter: an atomic counter to keep track of # of downloaded files
:return: a function that actually downloads a file given these params :return: a function that actually downloads a file given these params
""" """
def download(d): def download(d):
"""Binds voxforge_url, archive_dir, total, and counter into this scope """Binds voxforge_url, archive_dir, total, and counter into this scope
Downloads the given file Downloads the given file
@ -68,12 +67,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
of the file to download and file is the name of the file to download of the file to download and file is the name of the file to download
""" """
(i, file) = d (i, file) = d
download_url = voxforge_url + '/' + file download_url = voxforge_url + "/" + file
c = counter.increment() c = counter.increment()
print('Downloading file {} ({}/{})...'.format(i+1, c, total)) print("Downloading file {} ({}/{})...".format(i + 1, c, total))
maybe_download(filename_of(download_url), archive_dir, download_url) maybe_download(filename_of(download_url), archive_dir, download_url)
return download return download
def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter): def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter):
"""Generate a function to extract a tar file based on given parameters """Generate a function to extract a tar file based on given parameters
This works by currying the above given arguments into a closure This works by currying the above given arguments into a closure
@ -86,6 +87,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
:param counter: an atomic counter to keep track of # of extracted files :param counter: an atomic counter to keep track of # of extracted files
:return: a function that actually extracts a tar file given these params :return: a function that actually extracts a tar file given these params
""" """
def extract(d): def extract(d):
"""Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope """Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
Extracts the given file Extracts the given file
@ -95,58 +97,74 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
(i, archive) = d (i, archive) = d
if i < number_of_test: if i < number_of_test:
dataset_dir = path.join(data_dir, "test") dataset_dir = path.join(data_dir, "test")
elif i<number_of_test+number_of_dev: elif i < number_of_test + number_of_dev:
dataset_dir = path.join(data_dir, "dev") dataset_dir = path.join(data_dir, "dev")
else: else:
dataset_dir = path.join(data_dir, "train") dataset_dir = path.join(data_dir, "train")
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))): if not gfile.Exists(
os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
):
c = counter.increment() c = counter.increment()
print('Extracting file {} ({}/{})...'.format(i+1, c, total)) print("Extracting file {} ({}/{})...".format(i + 1, c, total))
tar = tarfile.open(archive) tar = tarfile.open(archive)
tar.extractall(dataset_dir) tar.extractall(dataset_dir)
tar.close() tar.close()
return extract return extract
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir # Conditionally download data to data_dir
if not path.isdir(data_dir): if not path.isdir(data_dir):
makedirs(data_dir) makedirs(data_dir)
archive_dir = data_dir+"/archive" archive_dir = data_dir + "/archive"
if not path.isdir(archive_dir): if not path.isdir(archive_dir):
makedirs(archive_dir) makedirs(archive_dir)
print("Downloading Voxforge data set into {} if not already present...".format(archive_dir)) print(
"Downloading Voxforge data set into {} if not already present...".format(
archive_dir
)
)
voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit' voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
html_page = urllib.request.urlopen(voxforge_url) html_page = urllib.request.urlopen(voxforge_url)
soup = BeautifulSoup(html_page, 'html.parser') soup = BeautifulSoup(html_page, "html.parser")
# list all links # list all links
refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']] refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]
# download files in parallel # download files in parallel
print('{} files to download'.format(len(refs))) print("{} files to download".format(len(refs)))
downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter()) downloader = _parallel_downloader(
voxforge_url, archive_dir, len(refs), AtomicCounter()
)
p = ThreadPool(NUM_PARALLEL) p = ThreadPool(NUM_PARALLEL)
p.map(downloader, enumerate(refs)) p.map(downloader, enumerate(refs))
# Conditionally extract data to dataset_dir # Conditionally extract data to dataset_dir
if not path.isdir(path.join(data_dir,"test")): if not path.isdir(os.path.join(data_dir, "test")):
makedirs(path.join(data_dir,"test")) makedirs(os.path.join(data_dir, "test"))
if not path.isdir(path.join(data_dir,"dev")): if not path.isdir(os.path.join(data_dir, "dev")):
makedirs(path.join(data_dir,"dev")) makedirs(os.path.join(data_dir, "dev"))
if not path.isdir(path.join(data_dir,"train")): if not path.isdir(os.path.join(data_dir, "train")):
makedirs(path.join(data_dir,"train")) makedirs(os.path.join(data_dir, "train"))
tarfiles = glob(path.join(archive_dir, "*.tgz")) tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
number_of_files = len(tarfiles) number_of_files = len(tarfiles)
number_of_test = number_of_files//100 number_of_test = number_of_files // 100
number_of_dev = number_of_files//100 number_of_dev = number_of_files // 100
# extract tars in parallel # extract tars in parallel
print("Extracting Voxforge data set into {} if not already present...".format(data_dir)) print(
extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()) "Extracting Voxforge data set into {} if not already present...".format(
data_dir
)
)
extracter = _parallel_extracter(
data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
)
p.map(extracter, enumerate(tarfiles)) p.map(extracter, enumerate(tarfiles))
# Generate data set # Generate data set
@ -156,42 +174,50 @@ def _download_and_preprocess_data(data_dir):
train_files = _generate_dataset(data_dir, "train") train_files = _generate_dataset(data_dir, "train")
# Write sets to disk as CSV files # Write sets to disk as CSV files
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False) train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False) dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False) test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
def _generate_dataset(data_dir, data_set): def _generate_dataset(data_dir, data_set):
extracted_dir = path.join(data_dir, data_set) extracted_dir = path.join(data_dir, data_set)
files = [] files = []
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")): for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
if path.isdir(path.join(promts_file[:-11],"wav")): if path.isdir(os.path.join(promts_file[:-11], "wav")):
with codecs.open(promts_file, 'r', 'utf-8') as f: with codecs.open(promts_file, "r", "utf-8") as f:
for line in f: for line in f:
id = line.split(' ')[0].split('/')[-1] id = line.split(" ")[0].split("/")[-1]
sentence = ' '.join(line.split(' ')[1:]) sentence = " ".join(line.split(" ")[1:])
sentence = re.sub("[^a-z']"," ",sentence.strip().lower()) sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
transcript = "" transcript = ""
for token in sentence.split(" "): for token in sentence.split(" "):
word = token.strip() word = token.strip()
if word!="" and word!=" ": if word != "" and word != " ":
transcript += word + " " transcript += word + " "
transcript = unicodedata.normalize("NFKD", transcript.strip()) \ transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", transcript.strip())
.decode("ascii", "ignore") .encode("ascii", "ignore")
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav") .decode("ascii", "ignore")
)
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
if gfile.Exists(wav_file): if gfile.Exists(wav_file):
wav_filesize = path.getsize(wav_file) wav_filesize = path.getsize(wav_file)
# remove audios that are shorter than 0.5s and longer than 20s. # remove audios that are shorter than 0.5s and longer than 20s.
# remove audios that are too short for transcript. # remove audios that are too short for transcript.
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \ if (
wav_filesize/len(transcript)>1400: (wav_filesize / 32000) > 0.5
files.append((path.abspath(wav_file), wav_filesize, transcript)) and (wav_filesize / 32000) < 20
and transcript != ""
and wav_filesize / len(transcript) > 1400
):
files.append(
(os.path.abspath(wav_file), wav_filesize, transcript)
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"]) return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
if __name__=="__main__":
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1]) _download_and_preprocess_data(sys.argv[1])

View File

@ -1,15 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tfv1
import sys import sys
import tensorflow.compat.v1 as tfv1
def main(): def main():
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin: with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef() graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read()) graph_def.ParseFromString(fin.read())
print('\n'.join(sorted(set(n.op for n in graph_def.node)))) print("\n".join(sorted(set(n.op for n in graph_def.node))))
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -3,19 +3,13 @@
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 build_sdb.py -h" for help Use "python3 build_sdb.py -h" for help
""" """
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import random
import argparse import argparse
import random
import sys
from util.sample_collections import samples_from_file, LabeledSample from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from util.audio import AUDIO_TYPE_PCM from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file
def play_sample(samples, index): def play_sample(samples, index):
@ -24,7 +18,7 @@ def play_sample(samples, index):
if CLI_ARGS.random: if CLI_ARGS.random:
index = random.randint(0, len(samples)) index = random.randint(0, len(samples))
elif index >= len(samples): elif index >= len(samples):
print('No sample with index {}'.format(CLI_ARGS.start)) print("No sample with index {}".format(CLI_ARGS.start))
sys.exit(1) sys.exit(1)
sample = samples[index] sample = samples[index]
print('Sample "{}"'.format(sample.sample_id)) print('Sample "{}"'.format(sample.sample_id))
@ -50,13 +44,28 @@ def play_collection():
def handle_args(): def handle_args():
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) ' parser = argparse.ArgumentParser(
'and DeepSpeech CSV files') description="Tool for playing samples from Sample Databases (SDB files) "
parser.add_argument('collection', help='Sample DB or CSV file to play samples from') "and DeepSpeech CSV files"
parser.add_argument('--start', type=int, default=0, )
help='Sample index to start at (negative numbers are relative to the end of the collection)') parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)') parser.add_argument(
parser.add_argument('--random', action='store_true', help='If samples should be played in random order') "--start",
type=int,
default=0,
help="Sample index to start at (negative numbers are relative to the end of the collection)",
)
parser.add_argument(
"--number",
type=int,
default=-1,
help="Number of samples to play (-1 for endless)",
)
parser.add_argument(
"--random",
action="store_true",
help="If samples should be played in random order",
)
return parser.parse_args() return parser.parse_args()
@ -70,5 +79,5 @@ if __name__ == "__main__":
try: try:
play_collection() play_collection()
except KeyboardInterrupt: except KeyboardInterrupt:
print(' Stopped') print(" Stopped")
sys.exit(0) sys.exit(0)

View File

@ -1,17 +1,11 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
import argparse import argparse
import shutil import shutil
import sys
from util.text import Alphabet, UTF8Alphabet from deepspeech_training.util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet

View File

@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
.. code-block:: .. code-block::
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/ $ python3 -m venv $HOME/tmp/deepspeech-train-venv/
Once this command completes successfully, the environment will be ready to be activated. Once this command completes successfully, the environment will be ready to be activated.
@ -46,7 +46,7 @@ Install the required dependencies using ``pip3``\ :
.. code-block:: bash .. code-block:: bash
cd DeepSpeech cd DeepSpeech
pip3 install -r requirements.txt pip3 install -e .
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules: The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
@ -70,7 +70,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
.. code-block:: bash .. code-block:: bash
pip3 uninstall tensorflow pip3 uninstall tensorflow
pip3 install 'tensorflow-gpu==1.15.0' pip3 install 'tensorflow-gpu==1.15.2'
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_. Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.

158
evaluate.py Executable file → Normal file
View File

@ -2,155 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.evaluate_tools import calculate_and_print_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
if __name__ == '__main__': if __name__ == '__main__':
create_flags() try:
absl.app.run(main) from deepspeech_training import evaluate as ds_evaluate
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_evaluate.run_script()

View File

@ -10,13 +10,12 @@ import csv
import os import os
import sys import sys
from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model from deepspeech import Model
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
from util.evaluate_tools import calculate_and_print_report from deepspeech_training.util.flags import create_flags
from util.flags import create_flags from functools import partial
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from six.moves import zip, range
r''' r'''
This module should be self-contained: This module should be self-contained:

View File

@ -2,19 +2,18 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
import sys
import optuna
import absl.app import absl.app
from ds_ctcdecoder import Scorer import optuna
import sys
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from DeepSpeech import create_model from deepspeech_training.evaluate import evaluate
from evaluate import evaluate from deepspeech_training.train import create_model
from util.config import Config, initialize_globals from deepspeech_training.util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS from deepspeech_training.util.flags import create_flags, FLAGS
from util.logging import log_error from deepspeech_training.util.logging import log_error
from util.evaluate_tools import wer_cer_batch from deepspeech_training.util.evaluate_tools import wer_cer_batch
from ds_ctcdecoder import Scorer
def character_based(): def character_based():

View File

@ -1,24 +0,0 @@
# Main training requirements
tensorflow == 1.15.2
numpy == 1.18.1
progressbar2
six
pyxdg
attrdict
absl-py
semver
opuslib == 2.0.0
# Requirements for building native_client files
setuptools
# Requirements for importers
sox
bs4
pandas
requests
librosa
soundfile
# Requirements for optimizer
optuna

60
setup.py Normal file
View File

@ -0,0 +1,60 @@
from pathlib import Path
from setuptools import find_packages, setup
def main():
version_file = Path(__file__).parent / 'VERSION'
with open(str(version_file)) as fin:
version = fin.read().strip()
setup(
name='deepspeech_training',
version=version,
description='Training code for mozilla DeepSpeech',
url='https://github.com/mozilla/DeepSpeech',
author='Mozilla',
license='MPL-2.0',
# Classifiers help users find your project by categorizing it.
#
# For a list of valid classifiers, see https://pypi.org/classifiers/
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Topic :: Multimedia :: Sound/Audio :: Speech',
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
'Programming Language :: Python :: 3',
],
package_dir={'': 'training'},
packages=find_packages(where='training'),
python_requires='>=3.5, <4',
install_requires=[
'tensorflow == 1.15.2',
'numpy == 1.18.1',
'progressbar2',
'six',
'pyxdg',
'attrdict',
'absl-py',
'semver',
'opuslib == 2.0.0',
'optuna',
'sox',
'bs4',
'pandas',
'requests',
'librosa',
'soundfile',
],
# If there are data files included in your packages that need to be
# installed, specify them here.
package_data={
'deepspeech_training': [
'VERSION',
'GRAPH_VERSION',
],
},
)
if __name__ == '__main__':
main()

View File

@ -1,10 +1,29 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import os import functools
import pandas
from deepspeech_training.util.helpers import secs_to_hours
from pathlib import Path
def read_csvs(csv_files):
# Relative paths are relative to CSV location
def absolutify(csv, path):
path = Path(path)
if path.is_absolute():
return str(path)
return str(csv.parent / path)
sets = []
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
sets.append(file)
# Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True)
from util.helpers import secs_to_hours
from util.feeding import read_csvs
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -14,20 +33,16 @@ def main():
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels") parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample") parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
args = parser.parse_args() args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")] in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
csv_dataframe = read_csvs(in_files) csv_dataframe = read_csvs(in_files)
total_bytes = csv_dataframe['wav_filesize'].sum() total_bytes = csv_dataframe['wav_filesize'].sum()
total_files = len(csv_dataframe.index) total_files = len(csv_dataframe)
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
bytes_without_headers = total_bytes - 44 * total_files print('Total bytes:', total_bytes)
print('Total files:', total_files)
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8) print('Total time:', secs_to_hours(total_seconds))
print('total_bytes', total_bytes)
print('total_files', total_files)
print('bytes_without_headers', bytes_without_headers)
print('total_time', secs_to_hours(total_time))
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
set -o pipefail set -o pipefail
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail set +o pipefail
which deepspeech which deepspeech

View File

@ -17,7 +17,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail set +o pipefail
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}") decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")

View File

@ -16,7 +16,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail set +o pipefail
pushd ${HOME}/DeepSpeech/ds/ pushd ${HOME}/DeepSpeech/ds/

View File

@ -14,7 +14,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail set +o pipefail
pushd ${HOME}/DeepSpeech/ds/ pushd ${HOME}/DeepSpeech/ds/

View File

@ -1,10 +1,14 @@
import unittest import unittest
from argparse import Namespace from argparse import Namespace
from .importers import validate_label_eng, get_validate_label from deepspeech_training.util.importers import validate_label_eng, get_validate_label
from pathlib import Path
def from_here(path):
here = Path(__file__)
return here.parent / path
class TestValidateLabelEng(unittest.TestCase): class TestValidateLabelEng(unittest.TestCase):
def test_numbers(self): def test_numbers(self):
label = validate_label_eng("this is a 1 2 3 test") label = validate_label_eng("this is a 1 2 3 test")
self.assertEqual(label, None) self.assertEqual(label, None)
@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
self.assertEqual(f('toto1234[{[{[]'), None) self.assertEqual(f('toto1234[{[{[]'), None)
def test_get_validate_label_missing(self): def test_get_validate_label_missing(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py') args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
f = get_validate_label(args) f = get_validate_label(args)
self.assertEqual(f, None) self.assertEqual(f, None)
def test_get_validate_label(self): def test_get_validate_label(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py') args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
f = get_validate_label(args) f = get_validate_label(args)
l = f('toto') l = f('toto')
self.assertEqual(l, 'toto') self.assertEqual(l, 'toto')

View File

@ -1,7 +1,7 @@
import unittest import unittest
import os import os
from .text import Alphabet from deepspeech_training.util.text import Alphabet
class TestAlphabetParsing(unittest.TestCase): class TestAlphabetParsing(unittest.TestCase):

View File

@ -0,0 +1 @@
../../GRAPH_VERSION

View File

@ -0,0 +1 @@
../../VERSION

View File

View File

@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
from .util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()

View File

@ -0,0 +1,936 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get apropos variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structure in an old
# school TensorFlow structure where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note, that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
Next to total and average loss it returns the mean edit distance,
the decoded result and the batch's original Y.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to non finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPU's.
# The abstraction we intoduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can for each GPU introduce a
tower for which's batch we calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing each variable's average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()

View File

@ -5,7 +5,7 @@ import tempfile
import collections import collections
import numpy as np import numpy as np
from util.helpers import LimitingPool from .helpers import LimitingPool
DEFAULT_RATE = 16000 DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1 DEFAULT_CHANNELS = 1

View File

@ -2,8 +2,8 @@ import sys
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from util.flags import FLAGS from .flags import FLAGS
from util.logging import log_info, log_error, log_warn from .logging import log_info, log_error, log_warn
def _load_checkpoint(session, checkpoint_path): def _load_checkpoint(session, checkpoint_path):

View File

@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
from attrdict import AttrDict from attrdict import AttrDict
from xdg import BaseDirectory as xdg from xdg import BaseDirectory as xdg
from util.flags import FLAGS from .flags import FLAGS
from util.gpu import get_available_gpus from .gpu import get_available_gpus
from util.logging import log_error from .logging import log_error
from util.text import Alphabet, UTF8Alphabet from .text import Alphabet, UTF8Alphabet
from util.helpers import parse_file_size from .helpers import parse_file_size
class ConfigSingleton: class ConfigSingleton:
_config = None _config = None

View File

@ -7,8 +7,8 @@ import numpy as np
from attrdict import AttrDict from attrdict import AttrDict
from util.flags import FLAGS from .flags import FLAGS
from util.text import levenshtein from .text import levenshtein
def pmap(fun, iterable): def pmap(fun, iterable):

View File

@ -8,13 +8,13 @@ import tensorflow as tf
from tensorflow.python.ops import gen_audio_ops as contrib_audio from tensorflow.python.ops import gen_audio_ops as contrib_audio
from util.config import Config from .config import Config
from util.text import text_to_char_array from .text import text_to_char_array
from util.flags import FLAGS from .flags import FLAGS
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
from util.sample_collections import samples_from_files from .sample_collections import samples_from_files
from util.helpers import remember_exception, MEGABYTE from .helpers import remember_exception, MEGABYTE
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None): def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):

View File

@ -4,7 +4,7 @@ import os
import re import re
import sys import sys
from util.helpers import secs_to_hours from .helpers import secs_to_hours
from collections import Counter from collections import Counter
def get_counter(): def get_counter():

View File

@ -3,7 +3,7 @@ from __future__ import print_function
import progressbar import progressbar
import sys import sys
from util.flags import FLAGS from .flags import FLAGS
# Logging functions # Logging functions

View File

@ -5,8 +5,8 @@ import json
from pathlib import Path from pathlib import Path
from functools import partial from functools import partial
from util.helpers import MEGABYTE, GIGABYTE, Interleaved from .helpers import MEGABYTE, GIGABYTE, Interleaved
from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
BIG_ENDIAN = 'big' BIG_ENDIAN = 'big'
INT_SIZE = 4 INT_SIZE = 4

View File

@ -1,6 +1,7 @@
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from util.sparse_image_warp import sparse_image_warp
from .sparse_image_warp import sparse_image_warp
def augment_freq_time_mask(spectrogram, def augment_freq_time_mask(spectrogram,
frequency_masking_para=30, frequency_masking_para=30,

View File

@ -0,0 +1,167 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import argparse
import errno
import gzip
import os
import platform
import six.moves.urllib as urllib
import stat
import subprocess
import sys
from pkg_resources import parse_version
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert artifact_name
assert branch_name is not None
assert branch_name
return TASKCLUSTER_SCHEME % {'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
def maybe_download_tc(target_dir, tc_url, progress=True):
def report_progress(count, block_size, total_size):
percent = (count * block_size * 100) // total_size
sys.stdout.write("\rDownloading: %d%%" % percent)
sys.stdout.flush()
if percent >= 100:
print('\n')
assert target_dir is not None
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
is_gzip = False
if not os.path.isfile(target_file):
print('Downloading %s ...' % tc_url)
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
is_gzip = headers.get('Content-Encoding') == 'gzip'
else:
print('File already exists: %s' % target_file)
if is_gzip:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)
frw.truncate()
return target_file
def maybe_download_tc_bin(**kwargs):
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
final_stat = os.stat(final_file)
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
parser.add_argument('--arch', required=False,
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
parser.add_argument('--artifact', required=False,
default='native_client.tar.xz',
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
parser.add_argument('--source', required=False, default=None,
help='Name of the TaskCluster scheme to use.')
parser.add_argument('--branch', required=False,
help='Branch name to use. Defaulting to current content of VERSION file.')
parser.add_argument('--decoder', action='store_true',
help='Get URL to ds_ctcdecoder Python package.')
args = parser.parse_args()
if not args.target and not args.decoder:
print('Pass either --target or --decoder.')
sys.exit(1)
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
is_ucs2 = sys.maxunicode < 0x10ffff
if not args.arch:
if is_arm:
args.arch = 'arm64' if is_64bit else 'arm'
elif is_mac:
args.arch = 'osx'
else:
args.arch = 'cpu'
if not args.branch:
version_string = read('../VERSION').strip()
ds_version = parse_version(version_string)
args.branch = "v{}".format(version_string)
else:
ds_version = parse_version(args.branch)
if args.decoder:
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(map(str, sys.version_info[0:2]))
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch
)
ctc_arch = args.arch + '-ctc'
print(get_tc_url(ctc_arch, artifact, args.branch))
sys.exit(0)
if args.source is not None:
if args.source in DEFAULT_SCHEMES:
global TASKCLUSTER_SCHEME
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
else:
print('No such scheme: %s' % args.source)
sys.exit(1)
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
if args.artifact == "convert_graphdef_memmapped_format":
convert_graph_file = os.path.join(args.target, args.artifact)
final_stat = os.stat(convert_graph_file)
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
if __name__ == '__main__':
main()

View File

@ -12,13 +12,13 @@ tflogging.set_verbosity(tflogging.ERROR)
import logging import logging
logging.getLogger('sox').setLevel(logging.ERROR) logging.getLogger('sox').setLevel(logging.ERROR)
from multiprocessing import Process, cpu_count from deepspeech_training.util.audio import AudioFile
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.feeding import split_audio_file
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error, log_info, log_progress, create_progressbar
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import Config, initialize_globals from multiprocessing import Process, cpu_count
from util.audio import AudioFile
from util.feeding import split_audio_file
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_info, log_progress, create_progressbar
def fail(message, code=1): def fail(message, code=1):
@ -27,8 +27,8 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path): def transcribe_file(audio_path, tlog_path):
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
from util.checkpoints import load_or_init_graph from deepspeech_training.util.checkpoints import load_or_init_graph
initialize_globals() initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try: try:

View File

@ -1,159 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import absolute_import
import os
import subprocess
import csv
from threading import Thread
from time import time
from scipy.interpolate import spline
from six.moves import range
# Do this to be able to use without X
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
class GPUUsage(Thread):
def __init__(self, csvfile=None):
super(GPUUsage, self).__init__()
self._cmd = [ 'nvidia-smi', 'dmon', '-d', '1', '-s', 'pucvmet' ]
self._names = []
self._units = []
self._process = None
self._csv_output = csvfile or os.environ.get('ds_gpu_usage_csv', self.make_basename(prefix='ds-gpu-usage', extension='csv'))
def get_git_desc(self):
return subprocess.check_output(['git', 'describe', '--always', '--abbrev']).strip()
def make_basename(self, prefix, extension):
# Let us assume that this code is executed in the current git clone
return '%s.%s.%s.%s' % (prefix, self.get_git_desc(), int(time()), extension)
def stop(self):
if not self._process:
print("Trying to stop nvidia-smi but no more process, please fix.")
return
print("Ending nvidia-smi monitoring: PID", self._process.pid)
self._process.terminate()
print("Ended nvidia-smi monitoring ...")
def run(self):
print("Starting nvidia-smi monitoring")
# If the system has no CUDA setup, then this will fail.
try:
self._process = subprocess.Popen(self._cmd, stdout=subprocess.PIPE)
except OSError as ex:
print("Unable to start monitoring, check your environment:", ex)
return
writer = None
with open(self._csv_output, 'w') as f:
for line in iter(self._process.stdout.readline, ''):
d = self.ingest(line)
if line.startswith('# '):
if len(self._names) == 0:
self._names = d
writer = csv.DictWriter(f, delimiter=str(','), quotechar=str('"'), fieldnames=d)
writer.writeheader()
continue
if len(self._units) == 0:
self._units = d
continue
else:
assert len(self._names) == len(self._units)
assert len(d) == len(self._names)
assert len(d) > 1
writer.writerow(self.merge_line(d))
f.flush()
def ingest(self, line):
return map(lambda x: x.replace('-', '0'), filter(lambda x: len(x) > 0, map(lambda x: x.strip(), line.split(' ')[1:])))
def merge_line(self, line):
return dict(zip(self._names, line))
class GPUUsageChart():
def __init__(self, source, basename=None):
self._rows = [ 'pwr', 'temp', 'sm', 'mem']
self._titles = {
'pwr': "Power (W)",
'temp': "Temperature (°C)",
'sm': "Streaming Multiprocessors (%)",
'mem': "Memory (%)"
}
self._data = { }.fromkeys(self._rows)
self._csv = source
self._basename = basename or os.environ.get('ds_gpu_usage_charts', 'gpu_usage_%%s_%d.png' % int(time.time()))
# This should make sure we start from anything clean.
plt.close("all")
try:
self.read()
for plot in self._rows:
self.produce_plot(plot)
except IOError as ex:
print("Unable to read", ex)
def append_data(self, row):
for bucket, value in row.iteritems():
if not bucket in self._rows:
continue
if not self._data[bucket]:
self._data[bucket] = {}
gpu = int(row['gpu'])
if not self._data[bucket].has_key(gpu):
self._data[bucket][gpu] = [ value ]
else:
self._data[bucket][gpu] += [ value ]
def read(self):
print("Reading data from", self._csv)
with open(self._csv, 'r') as f:
for r in csv.DictReader(f):
self.append_data(r)
def produce_plot(self, key, with_spline=True):
png = self._basename % (key, )
print("Producing plot for", key, "as", png)
fig, axis = plt.subplots()
data = self._data[key]
if data is None:
print("Data was empty, aborting")
return
x = list(range(len(data[0])))
if with_spline:
x = map(lambda x: float(x), x)
x_sm = np.array(x)
x_smooth = np.linspace(x_sm.min(), x_sm.max(), 300)
for gpu, y in data.iteritems():
if with_spline:
y = map(lambda x: float(x), y)
y_sm = np.array(y)
y_smooth = spline(x, y, x_smooth, order=1)
axis.plot(x_smooth, y_smooth, label='GPU %d' % (gpu))
else:
axis.plot(x, y, label='GPU %d' % (gpu))
axis.legend(loc="upper right", frameon=False)
axis.set_xlabel("Time (s)")
axis.set_ylabel("%s" % self._titles[key])
fig.set_size_inches(24, 18)
plt.title("GPU Usage: %s" % self._titles[key])
plt.savefig(png, dpi=100)
plt.close(fig)

View File

@ -1,168 +1,12 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division from __future__ import absolute_import, division, print_function
import argparse
import platform
import subprocess
import sys
import os
import errno
import stat
import gzip
import six.moves.urllib as urllib
from pkg_resources import parse_version
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert artifact_name
assert branch_name is not None
assert branch_name
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
def maybe_download_tc(target_dir, tc_url, progress=True):
def report_progress(count, block_size, total_size):
percent = (count * block_size * 100) // total_size
sys.stdout.write("\rDownloading: %d%%" % percent)
sys.stdout.flush()
if percent >= 100:
print('\n')
assert target_dir is not None
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
is_gzip = False
if not os.path.isfile(target_file):
print('Downloading %s ...' % tc_url)
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
is_gzip = headers.get('Content-Encoding') == 'gzip'
else:
print('File already exists: %s' % target_file)
if is_gzip:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)
frw.truncate()
return target_file
def maybe_download_tc_bin(**kwargs):
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
final_stat = os.stat(final_file)
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
parser.add_argument('--arch', required=False,
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
parser.add_argument('--artifact', required=False,
default='native_client.tar.xz',
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
parser.add_argument('--source', required=False, default=None,
help='Name of the TaskCluster scheme to use.')
parser.add_argument('--branch', required=False,
help='Branch name to use. Defaulting to current content of VERSION file.')
parser.add_argument('--decoder', action='store_true',
help='Get URL to ds_ctcdecoder Python package.')
args = parser.parse_args()
if not args.target and not args.decoder:
print('Pass either --target or --decoder.')
exit(1)
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
is_ucs2 = sys.maxunicode < 0x10ffff
if not args.arch:
if is_arm:
args.arch = 'arm64' if is_64bit else 'arm'
elif is_mac:
args.arch = 'osx'
else:
args.arch = 'cpu'
if not args.branch:
version_string = read('../VERSION').strip()
ds_version = parse_version(version_string)
args.branch = "v{}".format(version_string)
else:
ds_version = parse_version(args.branch)
if args.decoder:
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(map(str, sys.version_info[0:2]))
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch
)
ctc_arch = args.arch + '-ctc'
print(get_tc_url(ctc_arch, artifact, args.branch))
exit(0)
if args.source is not None:
if args.source in DEFAULT_SCHEMES:
global TASKCLUSTER_SCHEME
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
else:
print('No such scheme: %s' % args.source)
exit(1)
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
if args.artifact == "convert_graphdef_memmapped_format":
convert_graph_file = os.path.join(args.target, args.artifact)
final_stat = os.stat(convert_graph_file)
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
if __name__ == '__main__': if __name__ == '__main__':
main() try:
from deepspeech_training.util import taskcluster as dsu_taskcluster
except ImportError:
print('Training package is not installed. See training documentation.')
raise
dsu_taskcluster.main()