Merge pull request #2856 from reuben/training-install

Package training code to avoid sys.path hacks
Reuben Morais 2020-03-31 15:42:42 +02:00 committed by GitHub
commit 83d22e591b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
78 changed files with 3276 additions and 2699 deletions

.isort.cfg

@@ -0,0 +1,4 @@
[settings]
line_length=80
multi_line_output=3
default_section=FIRSTPARTY
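
These isort settings match the import style applied throughout this commit: multi_line_output=3 is isort's vertical-hanging-indent mode, so imports that exceed line_length=80 are wrapped with one name per line and a trailing comma. A small before/after sketch using names from this commit:

# Before isort (exceeds 80 characters):
from deepspeech_training.util.importers import get_counter, get_imported_samples, print_import_report

# After isort with multi_line_output=3 (vertical hanging indent):
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    print_import_report,
)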

.travis.yml

@@ -9,7 +9,7 @@ python:
jobs:
include:
- stage: cardboard linter
- name: cardboard linter
install:
- pip install --upgrade cardboardlint pylint
script:
@@ -17,9 +17,10 @@ jobs:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
fi
- stage: python unit tests
- name: python unit tests
install:
- pip install --upgrade -r requirements_tests.txt
- pip install --upgrade -r requirements_tests.txt;
pip install --upgrade .
script:
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
python -m unittest;

DeepSpeech.py

@@ -2,934 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
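# Editorial sketch (not part of the original file): the sys.argv scan above has
# to run before absl parses flags, because TF_CPP_MIN_LOG_LEVEL only silences
# TensorFlow's C++ logging if it is set before TensorFlow is imported.
# The same lookup, rerun on a hypothetical argv:
_example_argv = ['DeepSpeech.py', '--log_level', '1']
_example_index = _example_argv.index('--log_level') + 1 if '--log_level' in _example_argv else 0
assert (_example_argv[_example_index] if 0 < _example_index < len(_example_argv) else '3') == '1'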
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version, ExceptionBox
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get the appropriate variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
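# Editorial sketch (not part of the original file): what the identity-filter
# convolution above computes, traced in plain NumPy. Convolving with a reshaped
# identity matrix copies each input frame into every window that covers it,
# which yields overlapping context windows without any gather ops.
def _windowing_sketch(n_steps=8, n_context=2, n_input=3):
    window_width = 2 * n_context + 1
    x = np.random.rand(n_steps, n_input)                  # one utterance
    padded = np.pad(x, ((n_context, n_context), (0, 0)))  # 'SAME' zero padding
    windows = np.stack([padded[t:t + window_width] for t in range(n_steps)])
    return windows.shape                                  # (8, 5, 3)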
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structured in an old-
# school TensorFlow style where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note, that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
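# Editorial sketch (not part of the original file): the transpose/reshape steps
# at the top of create_model, traced with NumPy shapes only.
def _shape_sketch(batch=4, n_steps=7, window_width=5, n_input=3):
    x = np.zeros((batch, n_steps, window_width, n_input))
    x = x.transpose(1, 0, 2, 3)                # time-major, matching the RNN
    x = x.reshape(-1, window_width * n_input)  # rank 2 for the dense layers
    return x.shape                             # (n_steps * batch, 15) == (28, 15)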
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine computes the CTC loss for a mini-batch.
It returns the average loss across the batch, together with the names of any
files that produced a non-finite loss.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to a non-finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPUs.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can introduce a tower for each GPU,
calculate the optimization gradients for its batch,
and return them along with the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non-finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing, for each variable, the average of the gradients obtained from the GPUs.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
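# Editorial sketch (not part of the original file): the zip(*tower_gradients)
# grouping above, with hypothetical NumPy gradients from two towers.
def _average_gradients_sketch():
    tower_gradients = [
        [(np.array([1.0, 2.0]), 'w'), (np.array([0.5]), 'b')],  # tower 0
        [(np.array([3.0, 4.0]), 'w'), (np.array([1.5]), 'b')],  # tower 1
    ]
    averaged = []
    for grad_and_vars in zip(*tower_gradients):  # one group per variable
        grads = np.stack([g for g, _ in grad_and_vars])
        averaged.append((grads.mean(axis=0), grad_and_vars[0][1]))
    return averaged  # [(array([2., 3.]), 'w'), (array([1.]), 'b')]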
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
else:
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve in the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement,
# wait FLAGS.plateau_epochs before reducing the learning rate again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
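# Editorial sketch (not part of the original file): default=float is what lets
# the NumPy scalars inside the decoded samples serialize, since json calls the
# default hook for any object it cannot encode natively.
def _json_default_sketch():
    return json.dumps({'loss': np.float32(0.25)}, default=float)  # '{"loss": 0.25}'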
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size is None:  # batch_size was normalized to None above when <= 0
# no state management since n_steps is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# The TF Lite runtime will check that input dimensions are 1, 2 or 4.
# By default we get 3, the middle one being batch_size, which is forced to
# one on the inference graph, so remove that dimension.
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size is None:  # a dynamic batch size was requested
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite or streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expects n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
from tensorflow.python.framework.ops import Tensor, Operation
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
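# Editorial sketch (not part of the original file): joining with '' forces a
# trailing separator, so dirname() then yields the export directory itself,
# which shutil.make_archive uses as the archive base name (appending '.zip').
def _zip_name_sketch():
    export_dir = os.path.join('/path/to/export/LANG_CODE', '')  # hypothetical path
    return os.path.dirname(export_dir)  # '/path/to/export/LANG_CODE'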
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__':
create_flags()
absl.app.run(main)
try:
from deepspeech_training import train as ds_train
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_train.run_script()
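
The eleven-line shim above only works because this commit moves the training code into an installable deepspeech_training package. The packaging files themselves are not shown in this excerpt; the following is a minimal sketch of the kind of setup.py that would make the shim's import succeed, with names and metadata that are illustrative assumptions rather than the commit's actual file:

# Hypothetical minimal setup.py sketch -- not the commit's actual packaging file.
from setuptools import find_packages, setup

setup(
    name='deepspeech_training',
    version='0.0.1',                               # placeholder version
    packages=find_packages(include=['deepspeech_training*']),
    install_requires=['absl-py', 'progressbar2'],  # assumed subset of dependencies
)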

Dockerfile

@@ -150,7 +150,7 @@ COPY . /DeepSpeech/
WORKDIR /DeepSpeech
RUN pip3 --no-cache-dir install -r requirements.txt
RUN pip3 --no-cache-dir install .
# Link DeepSpeech native_client libs to tf folder
RUN ln -s /DeepSpeech/native_client /tensorflow

bin/build_sdb.py

@@ -1,53 +1,69 @@
#!/usr/bin/env python
'''
"""
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
"""
import argparse
import progressbar
from util.downloader import SIMPLE_BAR
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
from util.sample_collections import samples_from_files, DirectSDBWriter
from deepspeech_training.util.audio import (
AUDIO_TYPE_OPUS,
AUDIO_TYPE_WAV,
change_audio_types,
)
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
DirectSDBWriter,
samples_from_files,
)
AUDIO_TYPE_LOOKUP = {
'wav': AUDIO_TYPE_WAV,
'opus': AUDIO_TYPE_OPUS
}
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
def build_sdb():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
with DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
) as sdb_writer:
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
for sample in bar(
change_audio_types(
samples, audio_type=audio_type, processes=CLI_ARGS.workers
)
):
sdb_writer.add(sample)
def handle_args():
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
'from DeepSpeech CSV files and other SDB files')
parser.add_argument('sources', nargs='+',
help='Source CSV and/or SDB files - '
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
'already ordered from shortest to longest.')
parser.add_argument('target', help='SDB file to create')
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
help='Audio representation inside target SDB')
parser.add_argument('--workers', type=int, default=None,
help='Number of encoding SDB workers')
parser.add_argument('--unlabeled', action='store_true',
help='Whether to build an SDB with unlabeled (audio only) samples - '
'typically used for building noise augmentation corpora')
parser = argparse.ArgumentParser(
description="Tool for building Sample Databases (SDB files) "
"from DeepSpeech CSV files and other SDB files"
)
parser.add_argument(
"sources",
nargs="+",
help="Source CSV and/or SDB files - "
"Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
"already ordered from shortest to longest.",
)
parser.add_argument("target", help="SDB file to create")
parser.add_argument(
"--audio-type",
default="opus",
choices=AUDIO_TYPE_LOOKUP.keys(),
help="Audio representation inside target SDB",
)
parser.add_argument(
"--workers", type=int, default=None, help="Number of encoding SDB workers"
)
parser.add_argument(
"--unlabeled",
action="store_true",
help="If to build an SDB with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora",
)
return parser.parse_args()


@@ -1,11 +0,0 @@
#!/usr/bin/env python
import sys
import os
sys.path.append(os.path.abspath('.'))
from util.gpu_usage import GPUUsage
gu = GPUUsage()
gu.start()


@@ -1,10 +0,0 @@
#!/usr/bin/env python
import sys
import os
sys.path.append(os.path.abspath('.'))
from util.gpu_usage import GPUUsageChart
GPUUsageChart(sys.argv[1], sys.argv[2])

bin/graphdef_binary_to_text.py

@@ -1,20 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tfv1
import sys
import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format
def main():
# Load and export as string
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read())
with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
fout.write(text_format.MessageToString(graph_def))
if __name__ == '__main__':
if __name__ == "__main__":
main()

bin/import_aidatatang.py

@@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import pandas
import os
import tarfile
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@@ -25,9 +19,9 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'aidatatang_200zh')
main_folder = os.path.join(target_dir, "aidatatang_200zh")
for targz in glob.glob(os.path.join(main_folder, 'corpus', '*', '*.tar.gz')):
for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
extract(targz, os.path.dirname(targz))
# Folder structure is now:
@@ -46,9 +40,11 @@ def preprocess_data(tgz_file, target_dir):
# Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aidatatang_200_zh_transcript.txt')
transcripts_path = os.path.join(
main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
)
with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
def load_set(glob_path):
set_files = []
@@ -57,33 +53,39 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n')
transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
subset_files = load_set(os.path.join(main_folder, 'corpus', subset, '*', '*.wav'))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(
os.path.join(main_folder, "corpus", subset, "*", "*.wav")
)
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'aidatatang_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)
def main():
# https://www.openslr.org/62/
parser = get_importers_parser(description='Import aidatatang_200zh corpus')
parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import aidatatang_200zh corpus")
parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:

bin/import_aishell.py

@@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import os
import tarfile
import pandas
from deepspeech_training.util.importers import get_importers_parser
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@@ -25,10 +19,10 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'data_aishell')
main_folder = os.path.join(target_dir, "data_aishell")
wav_archives_folder = os.path.join(main_folder, 'wav')
for targz in glob.glob(os.path.join(wav_archives_folder, '*.tar.gz')):
wav_archives_folder = os.path.join(main_folder, "wav")
for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
extract(targz, main_folder)
# Folder structure is now:
@@ -45,9 +39,11 @@ def preprocess_data(tgz_file, target_dir):
# Since the transcripts themselves can contain spaces, we split on space but
# only once, then build a mapping from file name to transcript
transcripts_path = os.path.join(main_folder, 'transcript', 'aishell_transcript_v0.8.txt')
transcripts_path = os.path.join(
main_folder, "transcript", "aishell_transcript_v0.8.txt"
)
with open(transcripts_path) as fin:
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
def load_set(glob_path):
set_files = []
@@ -56,33 +52,37 @@ def preprocess_data(tgz_file, target_dir):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.splitext(os.path.basename(wav))[0]
transcript = transcripts[transcript_key].strip('\n')
transcript = transcripts[transcript_key].strip("\n")
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
subset_files = load_set(os.path.join(main_folder, subset, 'S*', '*.wav'))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)
# Trim train set to under 10s by removing the last couple hundred samples
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'aishell_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)
def main():
# http://www.openslr.org/33/
parser = get_importers_parser(description='Import AISHELL corpus')
parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import AISHELL corpus")
parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:

bin/import_cv.py

@@ -1,34 +1,35 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import sox
import tarfile
import os
import subprocess
import progressbar
import sys
import tarfile
from glob import glob
from os import path
from multiprocessing import Pool
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
import progressbar
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
print_import_report,
)
from deepspeech_training.util.importers import validate_label_eng as validate_label
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = 'cv_corpus_v1'
ARCHIVE_NAME = ARCHIVE_DIR_NAME + '.tar.gz'
ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "cv_corpus_v1"
ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
ARCHIVE_URL = (
"https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
)
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data
@@ -36,56 +37,70 @@ def _download_and_preprocess_data(target_dir):
# Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
else:
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
for source_csv in glob(path.join(extracted_dir, '*.csv')):
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
extracted_dir = os.path.join(target_dir, extracted_data)
for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
_maybe_convert_set(
extracted_dir,
source_csv,
os.path.join(target_dir, os.path.split(source_csv)[-1]),
)
def one_sample(sample):
mp3_filename = sample[0]
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
)
file_size = -1
if path.exists(wav_filename):
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = validate_label(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
print()
if path.exists(target_csv):
if os.path.exists(target_csv):
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
return
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
@@ -93,14 +108,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
with open(source_csv) as source_csv_file:
reader = csv.DictReader(source_csv_file)
for row in reader:
samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))
# Mutable counters for the concurrent embedded routine
counter = get_counter()
num_samples = len(samples)
rows = []
print('Importing mp3 files...')
print("Importing mp3 files...")
pool = Pool()
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@@ -112,21 +127,28 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
pool.join()
print('Writing "%s"...' % target_csv)
with open(target_csv, 'w') as target_csv_file:
with open(target_csv, "w") as target_csv_file:
writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows):
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
@@ -134,5 +156,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
except sox.core.SoxError:
pass
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])
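
Both Common Voice importers apply the same "too short" filter in one_sample: frames / SAMPLE_RATE * 1000 / 10 / 2 estimates the number of output steps in a clip (assuming 16 kHz audio and, effectively, one step per 20 ms), and a clip with fewer steps than transcript characters cannot fit its transcript under CTC. A standalone sketch of that arithmetic:

# Sketch of the importers' "too short" filter; the rates mirror the constants above.
SAMPLE_RATE = 16000

def fits_transcript(frames, label):
    steps = int(frames / SAMPLE_RATE * 1000 / 10 / 2)  # roughly one step per 20 ms
    return steps >= len(str(label))

assert fits_transcript(16000, 'hello')  # 1 s of audio -> 50 steps >= 5 characters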

bin/import_cv2.py

@@ -1,90 +1,96 @@
#!/usr/bin/env python
'''
"""
Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator,
and formats the data and transcripts into a state usable by
DeepSpeech.py
Use "python3 import_cv2.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
"""
import csv
import sox
import os
import subprocess
import progressbar
import unicodedata
from os import path
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
for dataset in ["train", "test", "dev", "validated", "other"]:
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
if os.path.isfile(input_tsv):
print("Loading TSV file: ", input_tsv)
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
mp3_filename = sample[0]
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
mp3_filename += ".mp3"
# Storing wav files next to the mp3 ones - just with a different suffix
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
_maybe_convert_wav(mp3_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter_fun(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((os.path.split(wav_filename)[-1], file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
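
The two duration checks in one_sample read more easily with numbers plugged in. A worked example, assuming 16 kHz audio, one feature frame per 10 ms, and the importer's rule of thumb of at least two feature frames per transcript character:

SAMPLE_RATE = 16000
MAX_SECS = 10

frames = 80000                        # a 5 s clip at 16 kHz
secs = frames / SAMPLE_RATE           # 5.0 seconds of audio
feature_frames = secs * 1000 / 10     # 500 feature frames of 10 ms each
max_chars = int(feature_frames / 2)   # 250 = int(frames / SAMPLE_RATE * 1000 / 10 / 2)

label = "a plausible transcript"
too_short = max_chars < len(label)    # False: 250 frames comfortably fit 22 characters
too_long = secs > MAX_SECS            # False: 5 s is under the 10 s batch-size cap
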
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
output_csv = os.path.join(
audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
)
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
# Get audiofile path and transcript for each sentence in tsv
samples = []
with open(input_tsv, encoding='utf-8') as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter='\t')
with open(input_tsv, encoding="utf-8") as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter="\t")
for row in reader:
samples.append((path.join(audio_dir, row['path']), row['sentence']))
samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))
counter = get_counter()
num_samples = len(samples)
@ -101,26 +107,38 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
pool.close()
pool.join()
with open(output_csv, 'w', encoding='utf-8') as output_csv_file:
print('Writing CSV file for DeepSpeech.py as: ', output_csv)
with open(output_csv, "w", encoding="utf-8") as output_csv_file:
print("Writing CSV file for DeepSpeech.py as: ", output_csv)
writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
for filename, file_size, transcript in bar(rows):
if space_after_every_character:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': ' '.join(transcript)})
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": " ".join(transcript),
}
)
else:
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
writer.writerow(
{
"wav_filename": filename,
"wav_filesize": file_size,
"transcript": transcript,
}
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(mp3_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
@ -130,24 +148,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
if __name__ == "__main__":
PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
PARSER.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
PARSER.add_argument('--space_after_every_character', action='store_true', help='Insert a space after every character of the transcript')
PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
PARSER.add_argument(
"--audio_dir",
help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
)
PARSER.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
PARSER.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
PARSER.add_argument(
"--space_after_every_character",
action="store_true",
help="To help transcript join by white space",
)
PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS)
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
AUDIO_DIR = (
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
)
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
def label_filter_fun(label):
if PARAMS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:


@ -1,25 +1,20 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import codecs
import fnmatch
import os
import subprocess
import sys
import unicodedata
import librosa
import pandas
import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import validate_label_eng as validate_label
# Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import os
import pandas
import subprocess
import unicodedata
import librosa
import soundfile # <= Has an external dependency on libsndfile
from util.importers import validate_label_eng as validate_label
def _download_and_preprocess_data(data_dir):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
@ -29,33 +24,55 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")
# Conditionally split Fisher wav data
all_2004 = _split_wav_and_sentences(data_dir,
original_data="fisher-2004-wav",
converted_data="fisher-2004-split-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"))
all_2005 = _split_wav_and_sentences(data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"))
all_2004 = _split_wav_and_sentences(
data_dir,
original_data="fisher-2004-wav",
converted_data="fisher-2004-split-wav",
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
)
all_2005 = _split_wav_and_sentences(
data_dir,
original_data="fisher-2005-wav",
converted_data="fisher-2005-split-wav",
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
)
# The following files have incorrect transcripts that are much longer than
# their audio source. The result is that we end up with more labels than time
# slices, which breaks CTC.
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"), "transcript"] = "correct"
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"), "transcript"] = "that's one of those"
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"), "transcript"] = "they don't want"
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"), "transcript"] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
"transcript",
] = "correct"
all_2004.loc[
all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
"transcript",
] = "that's one of those"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
"transcript",
] = "they don't want"
all_2005.loc[
all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
"transcript",
] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
# The following file is just a short sound and not at all transcribed like provided.
# So we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")]
all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
]
# The following file is far too long and would ruin our training batch size.
# So we just exclude it.
all_2005 = all_2005[~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")]
all_2005 = all_2005[
~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
]
# The following file is too large for its transcript, so we just exclude it.
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")]
all_2004 = all_2004[
~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
]
# Conditionally split Fisher data into train/validation/test sets
train_2004, dev_2004, test_2004 = _split_sets(all_2004)
@ -71,6 +88,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)
def _maybe_convert_wav(data_dir, original_data, converted_data):
source_dir = os.path.join(data_dir, original_data)
target_dir = os.path.join(data_dir, converted_data)
@ -88,10 +106,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
for filename in fnmatch.filter(filenames, "*.sph"):
sph_file = os.path.join(root, filename)
for channel in ["1", "2"]:
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "_c" + channel + ".wav"
wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "_c"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename)
print("converting {} to {}".format(sph_file, wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file])
subprocess.check_call(
["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
)
def _parse_transcriptions(trans_file):
segments = []
@ -109,18 +135,23 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
segments.append({
"start_time": start_time,
"stop_time": stop_time,
"speaker": speaker,
"transcript": transcript,
})
segments.append(
{
"start_time": start_time,
"stop_time": stop_time,
"speaker": speaker,
"transcript": transcript,
}
)
return segments
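
The encode-decode dance above silently strips any diacritics that NFKD can decompose, rather than failing on them. A small standard-library illustration:

import unicodedata

text = "café déjà vu"
ascii_text = (
    unicodedata.normalize("NFKD", text)  # split accented chars into base + combining mark
    .encode("ascii", "ignore")           # bytes; the combining marks are dropped
    .decode("ascii", "ignore")           # back to str, as text_to_char_array expects
)
print(ascii_text)  # cafe deja vu
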
def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
trans_dir = os.path.join(data_dir, trans_data)
source_dir = os.path.join(data_dir, original_data)
@ -137,43 +168,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file
wav_filenames = [os.path.splitext(os.path.basename(trans_file))[0] + "_c" + channel + ".wav" for channel in ["1", "2"]]
wav_files = [os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames]
wav_filenames = [
os.path.splitext(os.path.basename(trans_file))[0]
+ "_c"
+ channel
+ ".wav"
for channel in ["1", "2"]
]
wav_files = [
os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
]
print("splitting {} according to {}".format(wav_files, trans_file))
origAudios = [librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files]
origAudios = [
librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
]
# Loop over segments and split wav_file for each segment
for segment in segments:
# Create wav segment filename
start_time = segment["start_time"]
stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = os.path.join(target_dir, new_wav_filename)
channel = 0 if segment["speaker"] == "A:" else 1
_split_and_resample_wav(origAudios[channel], start_time, stop_time, new_wav_file)
_split_and_resample_wav(
origAudios[channel], start_time, stop_time, new_wav_file
)
new_wav_filesize = os.path.getsize(new_wav_file)
transcript = validate_label(segment["transcript"])
if transcript != None:
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
def _split_audio(origAudio, start_time, stop_time):
audioData, frameRate = origAudio
nChannels = len(audioData.shape)
startIndex = int(start_time * frameRate)
stopIndex = int(stop_time * frameRate)
return audioData[startIndex: stopIndex] if 1 == nChannels else audioData[:, startIndex: stopIndex]
return (
audioData[startIndex:stopIndex]
if 1 == nChannels
else audioData[:, startIndex:stopIndex]
)
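
The slicing in _split_audio is plain sample-index arithmetic on the (data, rate) pair returned by librosa.load. A self-contained sketch of the multi-channel branch, assuming mono=False yields an array shaped (channels, samples):

import numpy as np

frame_rate = 16000
audio_data = np.zeros((2, 10 * frame_rate))  # 10 s of silent stereo audio
start_index = int(1.5 * frame_rate)          # 24000, as in _split_audio
stop_index = int(3.0 * frame_rate)           # 48000
chunk = audio_data[:, start_index:stop_index]
print(chunk.shape)                           # (2, 24000): 1.5 s from both channels
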
def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio[1]
chunkData = _split_audio(origAudio, start_time, stop_time)
soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")
def _split_sets(filelist):
# We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation.
@ -187,9 +248,12 @@ def _split_sets(filelist):
test_beg = dev_end
test_end = len(filelist)
return (filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end])
return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
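
Put differently, the nested 80/20 splits leave 64% of the files for training, 16% for validation and 20% for testing. A quick check of that arithmetic, assuming the boundaries fall exactly on the 80% marks:

n = 1000                        # files after filtering
train_end = int(0.8 * 0.8 * n)  # 640 -> 64% train
dev_end = int(0.8 * n)          # 800 -> files 640..799 are dev (16%)
print(train_end, dev_end - train_end, n - dev_end)  # 640 160 200
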
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])


@ -1,24 +1,18 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import numpy as np
import pandas
import os
import tarfile
import numpy as np
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -26,7 +20,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS')
main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")
# Folder structure is now:
# - ST-CMDS-20170001_1-OS/
@ -39,16 +33,16 @@ def preprocess_data(tgz_file, target_dir):
for wav in glob.glob(glob_path):
wav_filename = wav
wav_filesize = os.path.getsize(wav)
txt_filename = os.path.splitext(wav_filename)[0] + '.txt'
with open(txt_filename, 'r') as fin:
txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
with open(txt_filename, "r") as fin:
transcript = fin.read()
set_files.append((wav_filename, wav_filesize, transcript))
return set_files
# Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, '*.wav'))
all_files = load_set(os.path.join(main_folder, "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True)
df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df))
np.random.seed(12345)
@ -61,29 +55,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000]
train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'freestmandarin_train.csv')
print('Saving train set into {}...'.format(dest_csv))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False)
dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_dev.csv')
print('Saving dev set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'freestmandarin_test.csv')
print('Saving test set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False)
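
The duration formula above only works because these WAVs have a fixed layout: a 44-byte canonical RIFF header, 16 kHz sample rate, and 16-bit (2-byte) mono samples. Worked through for one hypothetical file:

wav_filesize = 352044                 # bytes on disk (made-up example)
header_bytes = 44                     # canonical RIFF/WAVE header
bytes_per_second = 16000 * 2          # 16 kHz mono at 2 bytes per sample
duration = (wav_filesize - header_bytes) / bytes_per_second
print(duration)                       # 11.0 -> this sample would be trimmed as > 10 s
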
def main():
# https://www.openslr.org/38/
parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:


@ -1,24 +1,18 @@
#!/usr/bin/env python
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import math
import urllib
import logging
from util.importers import get_importers_parser, get_validate_label
import math
import os
import subprocess
from os import path
import urllib
from pathlib import Path
import swifter
import pandas as pd
from sox import Transformer
import swifter
from deepspeech_training.util.importers import get_importers_parser, get_validate_label
__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@ -40,9 +34,7 @@ def parse_args(args):
Returns:
:obj:`argparse.Namespace`: command line parameters namespace
"""
parser = get_importers_parser(
description="Imports GramVaani data for Deep Speech"
)
parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
parser.add_argument(
"--version",
action="version",
@ -82,6 +74,7 @@ def parse_args(args):
)
return parser.parse_args(args)
def setup_logging(level):
"""Setup basic logging
Args:
@ -92,6 +85,7 @@ def setup_logging(level):
level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
)
class GramVaaniCSV:
"""GramVaaniCSV representing a GramVaani dataset.
Args:
@ -107,8 +101,17 @@ class GramVaaniCSV:
_logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
data = pd.read_csv(
os.path.abspath(csv_filename),
names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"],
usecols=["audio_url","transcript","audio_length"],
names=[
"piece_id",
"audio_url",
"transcript_labelled",
"transcript",
"labels",
"content_filename",
"audio_length",
"user_id",
],
usecols=["audio_url", "transcript", "audio_length"],
skiprows=[0],
engine="python",
encoding="utf-8",
@ -119,6 +122,7 @@ class GramVaaniCSV:
_logger.info("Parsed %d lines csv file." % len(data))
return data
class GramVaaniDownloader:
"""GramVaaniDownloader downloads a GramVaani dataset.
Args:
@ -138,15 +142,17 @@ class GramVaaniDownloader:
mp3_directory (os.path): The directory into which the associated mp3's were downloaded
"""
mp3_directory = self._pre_download()
self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True)
self.data.swifter.apply(
func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
)
return mp3_directory
def _pre_download(self):
mp3_directory = path.join(self.target_dir, "mp3")
if not path.exists(self.target_dir):
mp3_directory = os.path.join(self.target_dir, "mp3")
if not os.path.exists(self.target_dir):
_logger.info("Creating directory...%s", self.target_dir)
os.mkdir(self.target_dir)
if not path.exists(mp3_directory):
if not os.path.exists(mp3_directory):
_logger.info("Creating directory...%s", mp3_directory)
os.mkdir(mp3_directory)
return mp3_directory
@ -154,13 +160,14 @@ class GramVaaniDownloader:
def _download(self, audio_url, transcript, audio_length, mp3_directory):
if audio_url == "audio_url":
return
mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
if not path.exists(mp3_filename):
mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
if not os.path.exists(mp3_filename):
_logger.debug("Downloading mp3 file...%s", audio_url)
urllib.request.urlretrieve(audio_url, mp3_filename)
else:
_logger.debug("Already downloaded mp3 file...%s", audio_url)
class GramVaaniConverter:
"""GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
Args:
@ -181,37 +188,53 @@ class GramVaaniConverter:
wav_directory (os.path): The directory into which the associated wav's were downloaded
"""
wav_directory = self._pre_convert()
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
if not path.exists(wav_filename):
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
wav_filename = os.path.join(
wav_directory,
os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
)
if not os.path.exists(wav_filename):
_logger.debug(
"Converting mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
transformer = Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
transformer.convert(
samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
)
transformer.build(str(mp3_filename), str(wav_filename))
else:
_logger.debug("Already converted mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
_logger.debug(
"Already converted mp3 file %s to wav file %s"
% (mp3_filename, wav_filename)
)
return wav_directory
def _pre_convert(self):
wav_directory = path.join(self.target_dir, "wav")
if not path.exists(self.target_dir):
wav_directory = os.path.join(self.target_dir, "wav")
if not os.path.exists(self.target_dir):
_logger.info("Creating directory...%s", self.target_dir)
os.mkdir(self.target_dir)
if not path.exists(wav_directory):
if not os.path.exists(wav_directory):
_logger.info("Creating directory...%s", wav_directory)
os.mkdir(wav_directory)
return wav_directory
class GramVaaniDataSets:
def __init__(self, target_dir, wav_directory, gram_vaani_csv):
self.target_dir = target_dir
self.wav_directory = wav_directory
self.csv_data = gram_vaani_csv.data
self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.valid = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.train = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.dev = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.test = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.valid = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.train = pd.DataFrame(
columns=["wav_filename", "wav_filesize", "transcript"]
)
self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
def create(self):
self._convert_csv_data_to_raw_data()
@ -220,30 +243,45 @@ class GramVaaniDataSets:
self.valid = self.valid.sample(frac=1).reset_index(drop=True)
train_size, dev_size, test_size = self._calculate_data_set_sizes()
self.train = self.valid.loc[0:train_size]
self.dev = self.valid.loc[train_size:train_size+dev_size]
self.test = self.valid.loc[train_size+dev_size:train_size+dev_size+test_size]
self.dev = self.valid.loc[train_size : train_size + dev_size]
self.test = self.valid.loc[
train_size + dev_size : train_size + dev_size + test_size
]
def _convert_csv_data_to_raw_data(self):
self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[
["audio_url","transcript","audio_length"]
].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)
self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
["audio_url", "transcript", "audio_length"]
].swifter.apply(
func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
axis=1,
raw=True,
)
self.raw.reset_index()
def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
if audio_url == "audio_url":
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
mp3_filename = os.path.basename(audio_url)
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
wav_relative_filename = os.path.join(
"wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
)
wav_filesize = os.path.getsize(
os.path.join(self.target_dir, wav_relative_filename)
)
transcript = validate_label(transcript)
if None == transcript:
transcript = ""
return pd.Series([wav_relative_filename, wav_filesize, transcript])
return pd.Series([wav_relative_filename, wav_filesize, transcript])
def _is_valid_raw_rows(self):
is_valid_raw_transcripts = self._is_valid_raw_transcripts()
is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
is_valid_raw_row = [(is_valid_raw_transcript & is_valid_raw_wav_frame) for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(is_valid_raw_transcripts, is_valid_raw_wav_frames)]
is_valid_raw_row = [
(is_valid_raw_transcript & is_valid_raw_wav_frame)
for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
is_valid_raw_transcripts, is_valid_raw_wav_frames
)
]
series = pd.Series(is_valid_raw_row)
return series
@ -252,16 +290,29 @@ class GramVaaniDataSets:
def _is_valid_raw_wav_frames(self):
transcripts = [str(transcript) for transcript in self.raw.transcript]
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
wav_filepaths = [
os.path.join(self.target_dir, str(wav_filename))
for wav_filename in self.raw.wav_filename
]
wav_frames = [
int(
subprocess.check_output(
["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
)
)
for wav_filepath in wav_filepaths
]
is_valid_raw_wav_frames = [
self._is_wav_frame_valid(wav_frame, transcript)
for wav_frame, transcript in zip(wav_frames, transcripts)
]
return pd.Series(is_valid_raw_wav_frames)
def _is_wav_frame_valid(self, wav_frame, transcript):
is_wav_frame_valid = True
if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
if int(wav_frame / SAMPLE_RATE * 1000 / 10 / 2) < len(str(transcript)):
is_wav_frame_valid = False
elif wav_frame/SAMPLE_RATE > MAX_SECS:
elif wav_frame / SAMPLE_RATE > MAX_SECS:
is_wav_frame_valid = False
return is_wav_frame_valid
@ -280,7 +331,14 @@ class GramVaaniDataSets:
def _save(self, dataset):
dataset_path = os.path.join(self.target_dir, dataset + ".csv")
dataframe = getattr(self, dataset)
dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
dataframe.to_csv(
dataset_path,
index=False,
encoding="utf-8",
escapechar="\\",
quoting=csv.QUOTE_MINIMAL,
)
def main(args):
"""Main entry point allowing external calls
@ -304,4 +362,5 @@ def main(args):
datasets.save()
_logger.info("Finished GramVaani importer...")
main(sys.argv[1:])


@ -1,28 +1,33 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import sys
import pandas
from util.downloader import maybe_download
from deepspeech_training.util.downloader import maybe_download
def _download_and_preprocess_data(data_dir):
# Conditionally download data
LDC93S1_BASE = "LDC93S1"
LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav")
trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt")
local_file = maybe_download(
LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
)
trans_file = maybe_download(
LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
)
with open(trans_file, "r") as fin:
transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '')
transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
".", ""
)
df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"])
df = pandas.DataFrame(
data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
columns=["wav_filename", "wav_filesize", "transcript"],
)
df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])
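
The one-liner that builds the transcript assumes the LDC93S1 .txt layout: two leading sample-offset tokens, then the sentence, ending in a period. An illustration of the same transformation (the offsets below are made up; the sentence is the published LDC93S1 prompt):

line = "0 46797 She had your dark suit in greasy wash water all year."
transcript = " ".join(line.strip().lower().split(" ")[2:]).replace(".", "")
print(transcript)  # she had your dark suit in greasy wash water all year
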


@ -1,33 +1,39 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import pandas
import progressbar
import os
import subprocess
import sys
import tarfile
import unicodedata
import pandas
import progressbar
from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile
from deepspeech_training.util.downloader import maybe_download
SAMPLE_RATE = 16000
def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir
print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir))
print(
"Downloading Librivox data set (55GB) into {} if not already present...".format(
data_dir
)
)
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz"
TRAIN_CLEAN_100_URL = (
"http://www.openslr.org/resources/12/train-clean-100.tar.gz"
)
TRAIN_CLEAN_360_URL = (
"http://www.openslr.org/resources/12/train-clean-360.tar.gz"
)
TRAIN_OTHER_500_URL = (
"http://www.openslr.org/resources/12/train-other-500.tar.gz"
)
DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
@ -35,12 +41,20 @@ def _download_and_preprocess_data(data_dir):
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"
def filename_of(x): return os.path.split(x)[1]
train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL)
def filename_of(x):
return os.path.split(x)[1]
train_clean_100 = maybe_download(
filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
)
bar.update(0)
train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL)
train_clean_360 = maybe_download(
filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
)
bar.update(1)
train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL)
train_other_500 = maybe_download(
filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
)
bar.update(2)
dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
@ -48,9 +62,13 @@ def _download_and_preprocess_data(data_dir):
dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
bar.update(4)
test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL)
test_clean = maybe_download(
filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
)
bar.update(5)
test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL)
test_other = maybe_download(
filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
)
bar.update(6)
# Conditionally extract LibriSpeech data
@ -61,11 +79,17 @@ def _download_and_preprocess_data(data_dir):
LIBRIVOX_DIR = "LibriSpeech"
work_dir = os.path.join(data_dir, LIBRIVOX_DIR)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
)
bar.update(0)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
)
bar.update(1)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500)
_maybe_extract(
data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
)
bar.update(2)
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
@ -91,28 +115,48 @@ def _download_and_preprocess_data(data_dir):
# data_dir/LibriSpeech/split-wav/1-2-2.txt
# ...
print("Converting FLAC to WAV and splitting transcriptions...")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav")
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
train_100 = _convert_audio_and_split_sentences(
work_dir, "train-clean-100", "train-clean-100-wav"
)
bar.update(0)
train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav")
train_360 = _convert_audio_and_split_sentences(
work_dir, "train-clean-360", "train-clean-360-wav"
)
bar.update(1)
train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav")
train_500 = _convert_audio_and_split_sentences(
work_dir, "train-other-500", "train-other-500-wav"
)
bar.update(2)
dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav")
dev_clean = _convert_audio_and_split_sentences(
work_dir, "dev-clean", "dev-clean-wav"
)
bar.update(3)
dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav")
dev_other = _convert_audio_and_split_sentences(
work_dir, "dev-other", "dev-other-wav"
)
bar.update(4)
test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav")
test_clean = _convert_audio_and_split_sentences(
work_dir, "test-clean", "test-clean-wav"
)
bar.update(5)
test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav")
test_other = _convert_audio_and_split_sentences(
work_dir, "test-other", "test-other-wav"
)
bar.update(6)
# Write sets to disk as CSV files
train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False)
train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False)
train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False)
train_100.to_csv(
os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
)
train_360.to_csv(
os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
)
train_500.to_csv(
os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
)
dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
@ -120,6 +164,7 @@ def _download_and_preprocess_data(data_dir):
test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(os.path.join(data_dir, extracted_data)):
@ -127,6 +172,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir)
tar.close()
def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
source_dir = os.path.join(extracted_dir, data_set)
target_dir = os.path.join(extracted_dir, dest_dir)
@ -149,20 +195,22 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
# We also convert the corresponding FLACs to WAV in the same pass
files = []
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, '*.trans.txt'):
for filename in fnmatch.filter(filenames, "*.trans.txt"):
trans_filename = os.path.join(root, filename)
with codecs.open(trans_filename, "r", "utf-8") as fin:
for line in fin:
# Parse each segment line
first_space = line.find(" ")
seqid, transcript = line[:first_space], line[first_space+1:]
seqid, transcript = line[:first_space], line[first_space + 1 :]
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
transcript = transcript.lower().strip()
@ -177,7 +225,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])
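
Each line of a LibriSpeech *.trans.txt pairs a sequence id with its transcript, separated by the first space, which is exactly what the find/slice above extracts. A short illustration, with the id format assumed from LibriSpeech's speaker-chapter-utterance naming:

line = "1089-134686-0000 HE HOPED THERE WOULD BE STEW FOR DINNER"
first_space = line.find(" ")
seqid, transcript = line[:first_space], line[first_space + 1:]
print(seqid)                        # 1089-134686-0000
print(transcript.lower().strip())   # he hoped there would be stew for dinner
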


@ -1,44 +1,39 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import argparse
import csv
import os
import re
import sox
import zipfile
import subprocess
import progressbar
import unicodedata
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import zipfile
from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = 'lingua_libre'
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip'
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "lingua_libre"
ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data
@ -46,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files and convert ogg data to wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -58,57 +54,70 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % archive_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
ogg_filename = sample[0]
# Storing wav files next to the ogg ones - just with a different suffix
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
_maybe_convert_wav(ogg_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1])
rows = []
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
)
if os.path.isfile(target_csv_template):
return
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
for record in glob(glob_dir, recursive=True):
record_file = record.replace(ogg_root_dir + os.path.sep, '')
record_file = record.replace(ogg_root_dir + os.path.sep, "")
if record_filter(record_file):
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
samples.append(
(
os.path.join(ogg_root_dir, record_file),
os.path.splitext(os.path.basename(record_file))[0],
)
)
counter = get_counter()
num_samples = len(samples)
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -139,7 +148,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
transcript = validate_label(item[2])
if not transcript:
continue
wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav'))
wav_filename = os.path.join(
ogg_root_dir, item[0].replace(".ogg", ".wav")
)
i_mod = i % 10
if i_mod == 0:
writer = test_writer
@ -147,38 +158,63 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(ogg_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
transformer.build(ogg_filename, wav_filename)
except sox.core.SoxError as ex:
print('SoX processing error', ex, ogg_filename, wav_filename)
print("SoX processing error", ex, ogg_filename, wav_filename)
def handle_args():
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
parser.add_argument(dest='target_dir')
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
parser.add_argument('--english-name', type=str, required=True, help='English name of the language')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus records to skip when importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items')
parser = get_importers_parser(
description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
)
parser.add_argument(dest="target_dir")
parser.add_argument(
"--qId", type=int, required=True, help="LinguaLibre language qId"
)
parser.add_argument(
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
)
parser.add_argument(
"--english-name", type=str, required=True, help="Enligh name of the language"
)
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--bogus-records",
type=argparse.FileType("r"),
required=False,
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -191,15 +227,17 @@ if __name__ == "__main__":
def record_filter(path):
if any(regex.match(path) for regex in bogus_regexes):
print('Reject', path)
print("Reject", path)
return False
return True
def label_filter(label):
if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:
@ -208,6 +246,14 @@ if __name__ == "__main__":
label = None
return label
ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
ARCHIVE_NAME = ARCHIVE_NAME.format(
qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
ARCHIVE_URL = ARCHIVE_URL.format(
qId=CLI_ARGS.qId,
iso639_3=CLI_ARGS.iso639_3,
language_English_name=CLI_ARGS.english_name,
)
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)
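
The i % 10 dispatch above is what yields the 80/10/10 split promised by the comments: every tenth sample goes to test, the next to dev, and the remaining eight to train. A quick sanity check:

from collections import Counter

split = Counter()
for i in range(1000):
    i_mod = i % 10
    if i_mod == 0:
        split["test"] += 1
    elif i_mod == 1:
        split["dev"] += 1
    else:
        split["train"] += 1
print(split)  # Counter({'train': 800, 'test': 100, 'dev': 100})
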


@ -1,43 +1,37 @@
#!/usr/bin/env python3
# pylint: disable=invalid-name
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import os
import subprocess
import progressbar
import unicodedata
import tarfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import unicodedata
from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
import progressbar
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15
ARCHIVE_DIR_NAME = '{language}'
ARCHIVE_NAME = '{language}.tgz'
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "{language}"
ARCHIVE_NAME = "{language}.tgz"
ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data
@ -48,8 +42,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -65,9 +59,13 @@ def one_sample(sample):
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1])
counter = get_counter()
rows = []
@ -75,27 +73,30 @@ def one_sample(sample):
if file_size == -1:
# Excluding samples that failed upon conversion
print("conversion failure", wav_filename)
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
)
if os.path.isfile(target_csv_template):
return
@ -103,14 +104,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(wav_root_dir, '**/metadata.csv')
glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
for record in glob(glob_dir, recursive=True):
if any(map(lambda sk: sk in record, SKIP_LIST)): # pylint: disable=cell-var-from-loop
if any(
map(lambda sk: sk in record, SKIP_LIST)
): # pylint: disable=cell-var-from-loop
continue
with open(record, 'r') as rec:
with open(record, "r") as rec:
for re in rec.readlines():
re = re.strip().split('|')
audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav')
re = re.strip().split("|")
audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
transcript = re[2]
samples.append((audio, transcript))
@ -129,9 +132,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -151,39 +154,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=os.path.relpath(wav_filename, extracted_dir),
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=os.path.relpath(wav_filename, extracted_dir),
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
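The three output CSVs share one schema (FIELDNAMES), and each kept row goes to exactly one csv.DictWriter. A self-contained sketch of the writing pattern used above, with made-up sample values:

    import csv

    FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]

    with open("train.csv", "w") as train_csv_file:
        writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
        writer.writeheader()
        # one row per kept sample, mirroring writer.writerow(dict(...)) above
        writer.writerow(
            dict(
                wav_filename="wavs/sample-0001.wav",
                wav_filesize=123456,
                transcript="hello world",
            )
        )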
def handle_args():
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
parser.add_argument(dest='target_dir')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated')
parser.add_argument('--language', required=True, type=str, help='Dataset language to use')
parser = get_importers_parser(
description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
)
parser.add_argument(dest="target_dir")
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--skiplist",
type=str,
default="",
help="Directories / books to skip, comma separated",
)
parser.add_argument(
"--language", required=True, type=str, help="Dataset language to use"
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
validate_label = get_validate_label(CLI_ARGS)
def label_filter(label):
if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:
View File
@ -1,30 +1,24 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import pandas
import os
import tarfile
import wave
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
def is_file_truncated(wav_filename, wav_filesize):
with wave.open(wav_filename, mode='rb') as fin:
with wave.open(wav_filename, mode="rb") as fin:
assert fin.getframerate() == 16000
assert fin.getsampwidth() == 2
assert fin.getnchannels() == 1
@ -37,8 +31,13 @@ def is_file_truncated(wav_filename, wav_filesize):
def preprocess_data(folder_with_archives, target_dir):
# First extract subset archives
for subset in ('train', 'dev', 'test'):
extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir)
for subset in ("train", "dev", "test"):
extract(
os.path.join(
folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
),
target_dir,
)
# Folder structure is now:
# - magicdata_{train,dev,test}.tar.gz
@ -54,58 +53,73 @@ def preprocess_data(folder_with_archives, target_dir):
# name, one containing the speaker ID, and one containing the transcription
def load_set(set_path):
transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0)
glob_path = os.path.join(set_path, '*', '*.wav')
transcripts = pandas.read_csv(
os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
)
glob_path = os.path.join(set_path, "*", "*.wav")
set_files = []
for wav in glob.glob(glob_path):
try:
wav_filename = wav
wav_filesize = os.path.getsize(wav)
transcript_key = os.path.basename(wav)
transcript = transcripts.loc[transcript_key, 'Transcription']
transcript = transcripts.loc[transcript_key, "Transcription"]
# Some files in this dataset are truncated, the header duration
# doesn't match the file size. This causes errors at training
# time, so check here if things are fine before including a file
if is_file_truncated(wav_filename, wav_filesize):
print('Warning: File {} is corrupted, header duration does '
'not match file size. Ignoring.'.format(wav_filename))
print(
"Warning: File {} is corrupted, header duration does "
"not match file size. Ignoring.".format(wav_filename)
)
continue
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
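The truncation check above compares the audio length promised by the WAV header against what is actually on disk; its body is elided from this hunk, so the following is only a sketch of an implementation consistent with the 16 kHz / 16-bit / mono asserts:

    import os
    import wave

    def looks_truncated(wav_filename, wav_filesize):
        # Bytes implied by the header: 44-byte canonical WAV header plus
        # frames * sample width * channels of PCM payload.
        with wave.open(wav_filename, mode="rb") as fin:
            expected = 44 + fin.getnframes() * fin.getsampwidth() * fin.getnchannels()
        return wav_filesize < expected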
for subset in ('train', 'dev', 'test'):
print('Loading {} set samples...'.format(subset))
for subset in ("train", "dev", "test"):
print("Loading {} set samples...".format(subset))
subset_files = load_set(os.path.join(target_dir, subset))
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
# Trim train set to under 10s
if subset == 'train':
durations = (df['wav_filesize'] - 44) / 16000 / 2
if subset == "train":
durations = (df["wav_filesize"] - 44) / 16000 / 2
df = df[durations <= 10.0]
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]')
df = df[~with_noise]
print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise)))
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset))
print('Saving {} set into {}...'.format(subset, dest_csv))
with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
df = df[~with_noise]
print(
"Trimming {} samples with noise ([FIL] or [SPK])".format(
sum(with_noise)
)
)
dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
print("Saving {} set into {}...".format(subset, dest_csv))
df.to_csv(dest_csv, index=False)
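The duration estimate (df["wav_filesize"] - 44) / 16000 / 2 works purely from file size: a canonical PCM WAV carries a 44-byte header, and 16 kHz mono audio at 2 bytes per sample occupies 32,000 bytes per second. A quick worked check:

    # 10 s of 16 kHz, 16-bit mono audio:
    # 10 * 16000 samples * 2 bytes = 320000 payload bytes + 44 header bytes
    wav_filesize = 320044
    duration = (wav_filesize - 44) / 16000 / 2
    assert duration == 10.0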
def main():
# https://openslr.org/68/
parser = get_importers_parser(description='Import MAGICDATA corpus')
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
parser = get_importers_parser(description="Import MAGICDATA corpus")
parser.add_argument(
"folder_with_archives",
help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
)
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
)
params = parser.parse_args()
if not params.target_dir:
params.target_dir = os.path.join(params.folder_with_archives, 'magicdata')
params.target_dir = os.path.join(params.folder_with_archives, "magicdata")
preprocess_data(params.folder_with_archives, params.target_dir)
View File
@ -1,25 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
import glob
import json
import numpy as np
import pandas
import os
import tarfile
import numpy as np
import pandas
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
def extract(archive_path, target_dir):
print('Extracting {} into {}...'.format(archive_path, target_dir))
print("Extracting {} into {}...".format(archive_path, target_dir))
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -27,7 +21,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
# First extract main archive and sub-archives
extract(tgz_file, target_dir)
main_folder = os.path.join(target_dir, 'primewords_md_2018_set1')
main_folder = os.path.join(target_dir, "primewords_md_2018_set1")
# Folder structure is now:
# - primewords_md_2018_set1/
@ -35,14 +29,11 @@ def preprocess_data(tgz_file, target_dir):
# - [0-f]/[00-0f]/*.wav
# - set1_transcript.json
transcripts_path = os.path.join(main_folder, 'set1_transcript.json')
transcripts_path = os.path.join(main_folder, "set1_transcript.json")
with open(transcripts_path) as fin:
transcripts = json.load(fin)
transcripts = {
entry['file']: entry['text']
for entry in transcripts
}
transcripts = {entry["file"]: entry["text"] for entry in transcripts}
def load_set(glob_path):
set_files = []
@ -54,13 +45,13 @@ def preprocess_data(tgz_file, target_dir):
transcript = transcripts[transcript_key]
set_files.append((wav_filename, wav_filesize, transcript))
except KeyError:
print('Warning: Missing transcript for WAV file {}.'.format(wav))
print("Warning: Missing transcript for WAV file {}.".format(wav))
return set_files
# Load all files, then deterministically split into train/dev/test sets
all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav'))
all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
df.sort_values(by='wav_filename', inplace=True)
df.sort_values(by="wav_filename", inplace=True)
indices = np.arange(0, len(df))
np.random.seed(12345)
@ -73,29 +64,33 @@ def preprocess_data(tgz_file, target_dir):
train_indices = indices[:-10000]
train_files = df.iloc[train_indices]
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
train_files = train_files[durations <= 15.0]
print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum()))
dest_csv = os.path.join(target_dir, 'primewords_train.csv')
print('Saving train set into {}...'.format(dest_csv))
print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
dest_csv = os.path.join(target_dir, "primewords_train.csv")
print("Saving train set into {}...".format(dest_csv))
train_files.to_csv(dest_csv, index=False)
dev_files = df.iloc[dev_indices]
dest_csv = os.path.join(target_dir, 'primewords_dev.csv')
print('Saving dev set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "primewords_dev.csv")
print("Saving dev set into {}...".format(dest_csv))
dev_files.to_csv(dest_csv, index=False)
test_files = df.iloc[test_indices]
dest_csv = os.path.join(target_dir, 'primewords_test.csv')
print('Saving test set into {}...'.format(dest_csv))
dest_csv = os.path.join(target_dir, "primewords_test.csv")
print("Saving test set into {}...".format(dest_csv))
test_files.to_csv(dest_csv, index=False)
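The split is deterministic because the index permutation is seeded before any shuffling; the exact dev/test boundaries are elided from the hunk above, so the 5,000/5,000 figures below are assumptions for illustration only:

    import numpy as np

    np.random.seed(12345)            # fixed seed: same split on every run
    indices = np.arange(0, 25000)    # hypothetical corpus size
    np.random.shuffle(indices)       # assumed shuffling step
    train_indices = indices[:-10000]
    dev_indices = indices[-10000:-5000]   # assumed boundary
    test_indices = indices[-5000:]        # assumed boundary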
def main():
# https://www.openslr.org/47/
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
parser.add_argument(
"--target_dir",
default="",
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
)
params = parser.parse_args()
if not params.target_dir:
View File
@ -1,45 +1,39 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import os
import re
import sox
import zipfile
import subprocess
import progressbar
import unicodedata
import tarfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import unicodedata
import zipfile
from glob import glob
from multiprocessing import Pool
from util.downloader import maybe_download
from util.text import Alphabet
from util.helpers import secs_to_hours
import progressbar
import sox
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
from deepspeech_training.util.text import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15
ARCHIVE_DIR_NAME = 'African_Accented_French'
ARCHIVE_NAME = 'African_Accented_French.tar.gz'
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "African_Accented_French"
ARCHIVE_NAME = "African_Accented_French.tar.gz"
ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract data
@ -47,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
# Produce CSV files
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -60,81 +55,89 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
else:
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
wav_filename = sample[0]
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = label_filter(sample[1])
counter = get_counter()
rows = []
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
target_csv_template = os.path.join(
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
)
if os.path.isfile(target_csv_template):
return
wav_root_dir = os.path.join(extracted_dir)
all_files = [
'transcripts/train/yaounde/fn_text.txt',
'transcripts/train/ca16_conv/transcripts.txt',
'transcripts/train/ca16_read/conditioned.txt',
'transcripts/dev/niger_west_african_fr/transcripts.txt',
'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv',
'transcripts/devtest/ca16_read/conditioned.txt',
'transcripts/test/ca16/prompts.txt',
"transcripts/train/yaounde/fn_text.txt",
"transcripts/train/ca16_conv/transcripts.txt",
"transcripts/train/ca16_read/conditioned.txt",
"transcripts/dev/niger_west_african_fr/transcripts.txt",
"speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
"transcripts/devtest/ca16_read/conditioned.txt",
"transcripts/test/ca16/prompts.txt",
]
transcripts = {}
for tr in all_files:
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source:
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
for line in tr_source.readlines():
line = line.strip()
if '.tsv' in tr:
sep = '\t'
if ".tsv" in tr:
sep = " "
else:
sep = ' '
sep = " "
audio = os.path.basename(line.split(sep)[0])
if not ('.wav' in audio):
if '.tdf' in audio:
audio = audio.replace('.tdf', '.wav')
if not (".wav" in audio):
if ".tdf" in audio:
audio = audio.replace(".tdf", ".wav")
else:
audio += '.wav'
audio += ".wav"
transcript = ' '.join(line.split(sep)[1:])
transcript = " ".join(line.split(sep)[1:])
transcripts[audio] = transcript
# Get audiofile path and transcript for each sentence in tsv
samples = []
glob_dir = os.path.join(wav_root_dir, '**/*.wav')
glob_dir = os.path.join(wav_root_dir, "**/*.wav")
for record in glob(glob_dir, recursive=True):
record_file = os.path.basename(record)
if record_file in transcripts:
@ -156,9 +159,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -178,25 +181,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def handle_args():
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
parser.add_argument(dest='target_dir')
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser = get_importers_parser(
description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
)
parser.add_argument(dest="target_dir")
parser.add_argument(
"--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -204,9 +220,11 @@ if __name__ == "__main__":
def label_filter(label):
if CLI_ARGS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \
.encode("ascii", "ignore") \
label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = validate_label(label)
if ALPHABET and label:
try:
View File
@ -1,44 +1,38 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
# ensure that you have downloaded the LDC dataset LDC97S62 and that the tarball exists in a folder, e.g.
# ./data/swb/swb1_LDC97S62.tgz
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import fnmatch
import pandas
import os
import subprocess
import sys
import tarfile
import unicodedata
import wave
import codecs
import tarfile
import requests
from util.importers import validate_label_eng as validate_label
import librosa
import soundfile # <= Has an external dependency on libsndfile
import pandas
import requests
import soundfile # <= Has an external dependency on libsndfile
from deepspeech_training.util.importers import validate_label_eng as validate_label
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
ARCHIVE_URL = 'http://www.openslr.org/resources/5/'
ARCHIVE_DIR_NAME = 'LDC97S62'
LDC_DATASET = 'swb1_LDC97S62.tgz'
ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
ARCHIVE_URL = "http://www.openslr.org/resources/5/"
ARCHIVE_DIR_NAME = "LDC97S62"
LDC_DATASET = "swb1_LDC97S62.tgz"
def download_file(folder, url):
# https://stackoverflow.com/a/16696317/738515
local_filename = url.split('/')[-1]
local_filename = url.split("/")[-1]
full_filename = os.path.join(folder, local_filename)
r = requests.get(url, stream=True)
with open(full_filename, 'wb') as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
with open(full_filename, "wb") as f:
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
return full_filename
@ -46,7 +40,7 @@ def download_file(folder, url):
def maybe_download(archive_url, target_dir, ldc_dataset):
# If archive file does not exist, download it...
archive_path = os.path.join(target_dir, ldc_dataset)
ldc_path = archive_url+ldc_dataset
ldc_path = archive_url + ldc_dataset
if not os.path.exists(target_dir):
print('No path "%s" - creating ...' % target_dir)
os.makedirs(target_dir)
@ -65,17 +59,23 @@ def _download_and_preprocess_data(data_dir):
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
# Check swb1_LDC97S62.tgz then extract
assert(os.path.isfile(archive_path))
assert os.path.isfile(archive_path)
_extract(target_dir, archive_path)
# Transcripts
transcripts_path = maybe_download(ARCHIVE_URL, target_dir, ARCHIVE_NAME)
_extract(target_dir, transcripts_path)
# Check swb1_d1/2/3/4/swb_ms98_transcriptions
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"]
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders]))
expected_folders = [
"swb1_d1",
"swb1_d2",
"swb1_d3",
"swb1_d4",
"swb_ms98_transcriptions",
]
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
# Conditionally convert swb sph data to wav
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
_maybe_convert_wav(target_dir, "swb1_d2", "swb1_d2-wav")
@ -83,13 +83,21 @@ def _download_and_preprocess_data(data_dir):
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
# Conditionally split wav data
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav")
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav")
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav")
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav")
d1 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
)
d2 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
)
d3 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
)
d4 = _maybe_split_wav_and_sentences(
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
)
swb_files = d1.append(d2).append(d3).append(d4)
train_files, dev_files, test_files = _split_sets(swb_files)
# Write sets to disk as CSV files
@ -97,7 +105,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(os.path.join(target_dir, "swb-dev.csv"), index=False)
test_files.to_csv(os.path.join(target_dir, "swb-test.csv"), index=False)
def _extract(target_dir, archive_path):
with tarfile.open(archive_path) as tar:
tar.extractall(target_dir)
@ -118,25 +126,46 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
# Loop over sph files in source_dir and convert each to 16-bit PCM wav
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, "*.sph"):
for channel in ['1', '2']:
for channel in ["1", "2"]:
sph_file = os.path.join(root, filename)
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav"
wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(target_dir, wav_filename)
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav"
temp_wav_filename = (
os.path.splitext(os.path.basename(sph_file))[0]
+ "-"
+ channel
+ "-temp.wav"
)
temp_wav_file = os.path.join(target_dir, temp_wav_filename)
print("converting {} to {}".format(sph_file, temp_wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file])
subprocess.check_call(
[
"sph2pipe",
"-c",
channel,
"-p",
"-f",
"rif",
sph_file,
temp_wav_file,
]
)
print("upsampling {} to {}".format(temp_wav_file, wav_file))
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
soundfile.write(wav_file, audioData, frameRate, "PCM_16")
os.remove(temp_wav_file)
def _parse_transcriptions(trans_file):
segments = []
with codecs.open(trans_file, "r", "utf-8") as fin:
for line in fin:
if line.startswith("#") or len(line) <= 1:
if line.startswith("#") or len(line) <= 1:
continue
tokens = line.split()
@ -150,15 +179,19 @@ def _parse_transcriptions(trans_file):
# We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array
# expects a string.
transcript = unicodedata.normalize("NFKD", transcript) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
transcript = (
unicodedata.normalize("NFKD", transcript)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
segments.append({
"start_time": start_time,
"stop_time": stop_time,
"transcript": transcript,
})
segments.append(
{
"start_time": start_time,
"stop_time": stop_time,
"transcript": transcript,
}
)
return segments
@ -183,8 +216,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A']
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav"
channel = ("2", "1")[
(os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
]
wav_filename = (
"sw0"
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
+ "-"
+ channel
+ ".wav"
)
wav_file = os.path.join(source_dir, wav_filename)
print("splitting {} according to {}".format(wav_file, trans_file))
@ -200,26 +241,39 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
# Create wav segment filename
start_time = segment["start_time"]
stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(
start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
os.path.splitext(os.path.basename(trans_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
if _is_wav_too_short(new_wav_filename):
continue
continue
new_wav_file = os.path.join(target_dir, new_wav_filename)
_split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = os.path.getsize(new_wav_file)
transcript = segment["transcript"]
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
files.append(
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
)
# Close origAudio
origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _is_wav_too_short(wav_filename):
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav']
short_wav_filenames = [
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
]
return wav_filename in short_wav_filenames
@ -234,7 +288,7 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
chunkAudio.writeframes(chunkData)
chunkAudio.close()
def _split_sets(filelist):
# We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation.
@ -248,10 +302,24 @@ def _split_sets(filelist):
test_beg = dev_end
test_end = len(filelist)
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end])
return (
filelist[train_beg:train_end],
filelist[dev_beg:dev_end],
filelist[test_beg:test_end],
)
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0):
def _read_data_set(
filelist,
thread_count,
batch_size,
numcep,
numcontext,
stride=1,
offset=0,
next_index=lambda i: i + 1,
limit=0,
):
# Optionally apply dataset size limit
if limit > 0:
filelist = filelist.iloc[:limit]
@ -259,7 +327,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
filelist = filelist[offset::stride]
# Return DataSet
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index)
return DataSet(
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
)
if __name__ == "__main__":
View File
@ -1,70 +1,76 @@
#!/usr/bin/env python
'''
"""
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
Use "python3 import_swc.py -h" for help
'''
from __future__ import absolute_import, division, print_function
"""
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import re
import csv
import sox
import wave
import shutil
import random
import tarfile
import argparse
import progressbar
import csv
import os
import random
import re
import shutil
import sys
import tarfile
import unicodedata
import wave
import xml.etree.cElementTree as ET
from os import path
from glob import glob
from collections import Counter
from glob import glob
from multiprocessing.pool import ThreadPool
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
import progressbar
import sox
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar"
LANGUAGES = ['dutch', 'english', 'german']
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker']
LANGUAGES = ["dutch", "english", "german"]
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
CHANNELS = 1
SAMPLE_RATE = 16000
UNKNOWN = '<unknown>'
AUDIO_PATTERN = 'audio*.ogg'
WAV_NAME = 'audio.wav'
ALIGNED_NAME = 'aligned.swc'
UNKNOWN = "<unknown>"
AUDIO_PATTERN = "audio*.ogg"
WAV_NAME = "audio.wav"
ALIGNED_NAME = "aligned.swc"
SUBSTITUTIONS = {
'german': [
(re.compile(r'\$'), 'dollar'),
(re.compile(r'€'), 'euro'),
(re.compile(r'£'), 'pfund'),
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
(re.compile(r'punkt null null null punkt null null null'), 'millionen'),
(re.compile(r'eins punkt null null null'), 'ein tausend'),
(re.compile(r'punkt null null null'), 'tausend'),
(re.compile(r'punkt null'), None)
"german": [
(re.compile(r"\$"), "dollar"),
(re.compile(r""), "euro"),
(re.compile(r"£"), "pfund"),
(
re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
r"\1zehnhundert \2er ",
),
(re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
(
re.compile(
r"eins punkt null null null punkt null null null punkt null null null"
),
"eine milliarde",
),
(
re.compile(
r"punkt null null null punkt null null null punkt null null null"
),
"milliarden",
),
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
(re.compile(r"eins punkt null null null"), "ein tausend"),
(re.compile(r"punkt null null null"), "tausend"),
(re.compile(r"punkt null"), None),
]
}
DONT_NORMALIZE = {
'german': 'ÄÖÜäöüß'
}
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
class Sample:
@ -98,11 +104,14 @@ def get_sample_size(population_size):
margin_of_error = 0.01
fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99%
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2
)
sample_size = 0
for train_size in range(population_size, 0, -1):
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
(margin_of_error ** 2 * train_size)
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
margin_of_error ** 2 * train_size
)
sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size:
break
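get_sample_size is Cochran's sample-size formula with a finite population correction: n0 = z^2 * p * (1 - p) / e^2, about 16,641 for a 99% confidence level (z = 2.58), a 1% margin of error, and the worst-case fraction p = 0.5. The loop then shrinks the candidate train size until two such samples (dev and test) fit alongside it. A worked check with a hypothetical population:

    z, e, p = 2.58, 0.01, 0.50
    numerator = (z ** 2 * p * (1 - p)) / (e ** 2)   # Cochran's n0, ~16641
    train_size = 100000                             # hypothetical population
    denominator = 1 + numerator / train_size        # finite population correction
    print(int(numerator / denominator))             # -> 14266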
@ -111,14 +120,16 @@ def get_sample_size(population_size):
def maybe_download_language(language):
lang_upper = language[0].upper() + language[1:]
return maybe_download(SWC_ARCHIVE.format(language=lang_upper),
CLI_ARGS.base_dir,
SWC_URL.format(language=lang_upper))
return maybe_download(
SWC_ARCHIVE.format(language=lang_upper),
CLI_ARGS.base_dir,
SWC_URL.format(language=lang_upper),
)
def maybe_extract(data_dir, extracted_data, archive):
extracted = path.join(data_dir, extracted_data)
if path.isdir(extracted):
extracted = os.path.join(data_dir, extracted_data)
if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted))
else:
print('Extracting "{}"...'.format(archive))
@ -133,29 +144,29 @@ def maybe_extract(data_dir, extracted_data, archive):
def ignored(node):
if node is None:
return False
if node.tag == 'ignored':
if node.tag == "ignored":
return True
return ignored(node.find('..'))
return ignored(node.find(".."))
def read_token(token):
texts, start, end = [], None, None
notes = token.findall('n')
notes = token.findall("n")
if len(notes) > 0:
for note in notes:
attributes = note.attrib
if start is None and 'start' in attributes:
start = int(attributes['start'])
if 'end' in attributes:
token_end = int(attributes['end'])
if start is None and "start" in attributes:
start = int(attributes["start"])
if "end" in attributes:
token_end = int(attributes["end"])
if end is None or token_end > end:
end = token_end
if 'pronunciation' in attributes:
t = attributes['pronunciation']
if "pronunciation" in attributes:
t = attributes["pronunciation"]
texts.append(t)
elif 'text' in token.attrib:
texts.append(token.attrib['text'])
return start, end, ' '.join(texts)
elif "text" in token.attrib:
texts.append(token.attrib["text"])
return start, end, " ".join(texts)
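read_token prefers timing and pronunciation data from the <n> child notes and falls back to the token's own text attribute. A hedged illustration with an invented token (the real SWC markup may differ in detail):

    import xml.etree.cElementTree as ET

    token = ET.fromstring(
        '<t text="Zuerich"><n start="1200" end="1750" pronunciation="Zuerich"/></t>'
    )
    # Note attributes win over the t/@text fallback:
    print(read_token(token))   # -> (1200, 1750, 'Zuerich')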
def in_alphabet(alphabet, c):
@ -163,10 +174,12 @@ def in_alphabet(alphabet, c):
ALPHABETS = {}
def get_alphabet(language):
if language in ALPHABETS:
return ALPHABETS[language]
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
alphabet = Alphabet(alphabet_path) if alphabet_path else None
ALPHABETS[language] = alphabet
return alphabet
@ -176,27 +189,35 @@ def label_filter(label, language):
label = label.translate(PRE_FILTER)
label = validate_label(label)
if label is None:
return None, 'validation'
return None, "validation"
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
for pattern, replacement in substitutions:
if replacement is None:
if pattern.match(label):
return None, 'substitution rule'
return None, "substitution rule"
else:
label = pattern.sub(replacement, label)
chars = []
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
alphabet = get_alphabet(language)
for c in label:
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
if (
CLI_ARGS.normalize
and c not in dont_normalize
and not in_alphabet(alphabet, c)
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c:
if not in_alphabet(alphabet, sc):
return None, 'illegal character'
return None, "illegal character"
chars.append(sc)
label = ''.join(chars)
label = "".join(chars)
label = validate_label(label)
return label, 'validation' if label is None else None
return label, "validation" if label is None else None
def collect_samples(base_dir, language):
@ -207,7 +228,9 @@ def collect_samples(base_dir, language):
samples = []
reasons = Counter()
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'):
def add_sample(
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
):
if p_start is not None and p_end is not None and p_text is not None:
duration = p_end - p_start
text, filter_reason = label_filter(p_text, language)
@ -217,53 +240,67 @@ def collect_samples(base_dir, language):
p_reason = filter_reason
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
skip = True
p_reason = 'unknown speaker'
p_reason = "unknown speaker"
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
skip = True
p_reason = 'unknown article'
p_reason = "unknown article"
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
skip = True
p_reason = 'exceeded duration'
p_reason = "exceeded duration"
elif int(duration / 30) < len(text):
skip = True
p_reason = 'too short to decode'
p_reason = "too short to decode"
elif duration / len(text) < 10:
skip = True
p_reason = 'length duration ratio'
p_reason = "length duration ratio"
if skip:
reasons[p_reason] += 1
else:
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker))
samples.append(
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
)
elif p_start is None or p_end is None:
reasons['missing timestamps'] += 1
reasons["missing timestamps"] += 1
else:
reasons['missing text'] += 1
reasons["missing text"] += 1
print('Collecting samples...')
print("Collecting samples...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
for root in bar(roots):
wav_path = path.join(root, WAV_NAME)
wav_path = os.path.join(root, WAV_NAME)
aligned = ET.parse(os.path.join(root, ALIGNED_NAME))
article = UNKNOWN
speaker = UNKNOWN
for prop in aligned.iter('prop'):
for prop in aligned.iter("prop"):
attributes = prop.attrib
if 'key' in attributes and 'value' in attributes:
if attributes['key'] == 'DC.identifier':
article = attributes['value']
elif attributes['key'] == 'reader.name':
speaker = attributes['value']
for sentence in aligned.iter('s'):
if "key" in attributes and "value" in attributes:
if attributes["key"] == "DC.identifier":
article = attributes["value"]
elif attributes["key"] == "reader.name":
speaker = attributes["value"]
for sentence in aligned.iter("s"):
if ignored(sentence):
continue
split = False
tokens = list(map(read_token, sentence.findall('t')))
tokens = list(map(read_token, sentence.findall("t")))
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
for token_start, token_end, token_text in tokens:
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='has numbers')
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="has numbers",
)
sample_start, sample_end, token_texts, sample_texts = (
None,
None,
[],
[],
)
continue
if sample_start is None:
sample_start = token_start
@ -271,20 +308,37 @@ def collect_samples(base_dir, language):
continue
token_texts.append(token_text)
if token_end is not None:
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='split')
if (
token_start != sample_start
and token_end - sample_start > CLI_ARGS.max_duration > 0
):
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split",
)
sample_start = sample_end
sample_texts = []
split = True
sample_end = token_end
sample_texts.extend(token_texts)
token_texts = []
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
p_reason='split' if split else 'complete')
print('Skipped samples:')
add_sample(
wav_path,
article,
speaker,
sample_start,
sample_end,
" ".join(sample_texts),
p_reason="split" if split else "complete",
)
print("Skipped samples:")
for reason, n in reasons.most_common():
print(' - {}: {}'.format(reason, n))
print(" - {}: {}".format(reason, n))
return samples
@ -294,8 +348,8 @@ def maybe_convert_one_to_wav(entry):
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
combiner = sox.Combiner()
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
output_wav = path.join(root, WAV_NAME)
if path.isfile(output_wav):
output_wav = os.path.join(root, WAV_NAME)
if os.path.isfile(output_wav):
return
files = sorted(glob(os.path.join(root, AUDIO_PATTERN)))
try:
@ -304,18 +358,18 @@ def maybe_convert_one_to_wav(entry):
elif len(files) > 1:
wav_files = []
for i, file in enumerate(files):
wav_path = path.join(root, 'audio{}.wav'.format(i))
wav_path = os.path.join(root, "audio{}.wav".format(i))
transformer.build(file, wav_path)
wav_files.append(wav_path)
combiner.set_input_format(file_type=['wav'] * len(wav_files))
combiner.build(wav_files, output_wav, 'concatenate')
combiner.set_input_format(file_type=["wav"] * len(wav_files))
combiner.build(wav_files, output_wav, "concatenate")
except sox.core.SoxError:
return
def maybe_convert_to_wav(base_dir):
roots = list(os.walk(base_dir))
print('Converting and joining source audio files...')
print("Converting and joining source audio files...")
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
tp = ThreadPool()
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
@ -335,53 +389,66 @@ def assign_sub_sets(samples):
sample_set.extend(speakers.pop(0))
train_set = sum(speakers, [])
if len(train_set) == 0:
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
print(
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
)
random.seed(42) # same source data == same output
random.shuffle(samples)
for index, sample in enumerate(samples):
if index < sample_size:
sample.sub_set = 'dev'
sample.sub_set = "dev"
elif index < 2 * sample_size:
sample.sub_set = 'test'
sample.sub_set = "test"
else:
sample.sub_set = 'train'
sample.sub_set = "train"
else:
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
for sub_set, sub_set_samples in [
("train", train_set),
("dev", sample_sets[0]),
("test", sample_sets[1]),
]:
for sample in sub_set_samples:
sample.sub_set = sub_set
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
.format(sub_set, len(sub_set_samples), t))
print(
'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
sub_set, len(sub_set_samples), t
)
)
def create_sample_dirs(language):
print('Creating sample directories...')
for set_name in ['train', 'dev', 'test']:
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
if not path.isdir(dir_path):
print("Creating sample directories...")
for set_name in ["train", "dev", "test"]:
dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
if not os.path.isdir(dir_path):
os.mkdir(dir_path)
def split_audio_files(samples, language):
print('Splitting audio files...')
print("Splitting audio files...")
sub_sets = Counter()
src_wav_files = group(samples, lambda s: s.wav_path).items()
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
for wav_path, file_samples in bar(src_wav_files):
file_samples = sorted(file_samples, key=lambda s: s.start)
with wave.open(wav_path, 'r') as src_wav_file:
with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate()
for sample in file_samples:
index = sub_sets[sample.sub_set]
sample_wav_path = path.join(CLI_ARGS.base_dir,
language + '-' + sample.sub_set,
'sample-{0:06d}.wav'.format(index))
sample_wav_path = os.path.join(
CLI_ARGS.base_dir,
language + "-" + sample.sub_set,
"sample-{0:06d}.wav".format(index),
)
sample.wav_path = sample_wav_path
sub_sets[sample.sub_set] += 1
src_wav_file.setpos(int(sample.start * rate / 1000.0))
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
with wave.open(sample_wav_path, 'w') as sample_wav_file:
data = src_wav_file.readframes(
int((sample.end - sample.start) * rate / 1000.0)
)
with wave.open(sample_wav_path, "w") as sample_wav_file:
sample_wav_file.setnchannels(src_wav_file.getnchannels())
sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
sample_wav_file.setframerate(rate)
@ -391,22 +458,26 @@ def split_audio_files(samples, language):
def write_csvs(samples, language):
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
base_dir = path.abspath(CLI_ARGS.base_dir)
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
base_dir = os.path.abspath(CLI_ARGS.base_dir)
csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
)
writer.writeheader()
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
bar = progressbar.ProgressBar(
max_value=len(set_samples), widgets=SIMPLE_BAR
)
for sample in bar(set_samples):
row = {
'wav_filename': path.relpath(sample.wav_path, base_dir),
'wav_filesize': path.getsize(sample.wav_path),
'transcript': sample.text
"wav_filename": os.path.relpath(sample.wav_path, base_dir),
"wav_filesize": os.path.getsize(sample.wav_path),
"transcript": sample.text,
}
if CLI_ARGS.add_meta:
row['article'] = sample.article
row['speaker'] = sample.speaker
row["article"] = sample.article
row["speaker"] = sample.speaker
writer.writerow(row)
@ -414,8 +485,8 @@ def cleanup(archive, language):
if not CLI_ARGS.keep_archive:
print('Removing archive "{}"...'.format(archive))
os.remove(archive)
language_dir = path.join(CLI_ARGS.base_dir, language)
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
language_dir = os.path.join(CLI_ARGS.base_dir, language)
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
print('Removing intermediate files in "{}"...'.format(language_dir))
shutil.rmtree(language_dir)
@ -433,34 +504,75 @@ def prepare_language(language):
def handle_args():
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
parser.add_argument('base_dir', help='Directory containing all data')
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
parser.add_argument('--exclude_numbers', type=bool, default=True,
help='If sequences with non-transliterated numbers should be excluded')
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
parser.add_argument('--ignore_too_long', type=bool, default=False,
help='If samples exceeding max_duration should be removed')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument(
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
)
parser.add_argument(
"--exclude_numbers",
type=bool,
default=True,
help="If sequences with non-transliterated numbers should be excluded",
)
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--ignore_too_long",
type=bool,
default=False,
help="If samples exceeding max_duration should be removed",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
for language in LANGUAGES:
parser.add_argument('--{}_alphabet'.format(language),
help='Exclude {} samples with characters not in provided alphabet file'.format(language))
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns')
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers')
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles')
parser.add_argument('--keep_archive', type=bool, default=True,
help='If downloaded archives should be kept')
parser.add_argument('--keep_intermediate', type=bool, default=False,
help='If intermediate files should be kept')
parser.add_argument(
"--{}_alphabet".format(language),
help="Exclude {} samples with characters not in provided alphabet file".format(
language
),
)
parser.add_argument(
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
)
parser.add_argument(
"--exclude_unknown_speakers",
action="store_true",
help="Exclude unknown speakers",
)
parser.add_argument(
"--exclude_unknown_articles",
action="store_true",
help="Exclude unknown articles",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
parser.add_argument(
"--keep_intermediate",
type=bool,
default=False,
help="If intermediate files should be kept",
)
return parser.parse_args()
if __name__ == "__main__":
CLI_ARGS = handle_args()
if CLI_ARGS.language == 'all':
if CLI_ARGS.language == "all":
for lang in LANGUAGES:
prepare_language(lang)
elif CLI_ARGS.language in LANGUAGES:
prepare_language(CLI_ARGS.language)
else:
fail('Wrong language id')
fail("Wrong language id")
View File
@ -1,24 +1,18 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import codecs
import pandas
import tarfile
import unicodedata
import wave
from glob import glob
from os import makedirs, path, remove, rmdir
import pandas
from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile
from util.stm import parse_stm_file
from deepspeech_training.util.downloader import maybe_download
from deepspeech_training.util.stm import parse_stm_file
def _download_and_preprocess_data(data_dir):
# Conditionally download data
@ -41,6 +35,7 @@ def _download_and_preprocess_data(data_dir):
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
def _maybe_extract(data_dir, extracted_data, archive):
# If data_dir/extracted_data does not exist, extract archive in data_dir
if not gfile.Exists(path.join(data_dir, extracted_data)):
@ -48,6 +43,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
tar.extractall(data_dir)
tar.close()
def _maybe_convert_wav(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
@ -61,6 +57,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
# Conditionally convert test sph to wav
_maybe_convert_wav_dataset(extracted_dir, "test")
def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Create source dir
source_dir = path.join(extracted_dir, data_set, "sph")
@ -84,6 +81,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
# Remove source_dir
rmdir(source_dir)
def _maybe_split_sentences(data_dir, extracted_data):
# Create extracted_data dir
extracted_dir = path.join(data_dir, extracted_data)
@ -99,6 +97,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
return train_files, dev_files, test_files
def _maybe_split_dataset(extracted_dir, data_set):
# Create stm dir
stm_dir = path.join(extracted_dir, data_set, "stm")
@ -116,14 +115,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
# Open wav corresponding to stm_file
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
wav_file = path.join(wav_dir, wav_filename)
origAudio = wave.open(wav_file,'r')
origAudio = wave.open(wav_file, "r")
# Loop over stm_segments and split wav_file for each segment
for stm_segment in stm_segments:
# Create wav segment filename
start_time = stm_segment.start_time
stop_time = stm_segment.stop_time
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
new_wav_filename = (
path.splitext(path.basename(stm_file))[0]
+ "-"
+ str(start_time)
+ "-"
+ str(stop_time)
+ ".wav"
)
new_wav_file = path.join(wav_dir, new_wav_filename)
# If the wav segment filename does not exist create it
@ -131,23 +137,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
_split_wav(origAudio, start_time, stop_time, new_wav_file)
new_wav_filesize = path.getsize(new_wav_file)
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript))
files.append(
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
)
# Close origAudio
origAudio.close()
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio.getframerate()
origAudio.setpos(int(start_time*frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate))
chunkAudio = wave.open(new_wav_file,'w')
origAudio.setpos(int(start_time * frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
chunkAudio = wave.open(new_wav_file, "w")
chunkAudio.setnchannels(origAudio.getnchannels())
chunkAudio.setsampwidth(origAudio.getsampwidth())
chunkAudio.setframerate(frameRate)
chunkAudio.writeframes(chunkData)
chunkAudio.close()
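For reference, the frame arithmetic above with hypothetical numbers (a 16 kHz file and a 1.5 s to 2.25 s segment):

frame_rate = 16000                               # hypothetical source file rate
seek_frame = int(1.5 * frame_rate)               # setpos() target -> 24000
frames_to_read = int((2.25 - 1.5) * frame_rate)  # readframes() count -> 12000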
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
'''
"""
NAME : LDC TIMIT Dataset
URL : https://catalog.ldc.upenn.edu/ldc93s1
HOURS : 5
@ -8,29 +8,32 @@
AUTHORS : Garofolo, John, et al.
TYPE : LDC Membership
LICENCE : LDC User Agreement
'''
"""
import errno
import fnmatch
import os
from os import path
import subprocess
import sys
import tarfile
import fnmatch
from os import path
import pandas as pd
import subprocess
def clean(word):
# Lowercase & strip punctuation which is not required
new = word.lower().replace('.', '')
new = new.replace(',', '')
new = new.replace(';', '')
new = new.replace('"', '')
new = new.replace('!', '')
new = new.replace('?', '')
new = new.replace(':', '')
new = new.replace('-', '')
new = word.lower().replace(".", "")
new = new.replace(",", "")
new = new.replace(";", "")
new = new.replace('"', "")
new = new.replace("!", "")
new = new.replace("?", "")
new = new.replace(":", "")
new = new.replace("-", "")
return new
def _preprocess_data(args):
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
@ -40,16 +43,24 @@ def _preprocess_data(args):
if ignoreSASentences:
print("Using recommended ignore SA sentences")
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)")
print(
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
)
else:
print("Using unrecommended setting to include SA sentences")
datapath = args
target = path.join(datapath, "TIMIT")
print("Checking to see if data has already been extracted in given argument: %s", target)
print(
    "Checking to see if data has already been extracted in given argument: %s"
    % target
)
if not path.isdir(target):
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath)
print(
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
datapath,
)
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
if path.isfile(filepath):
print("File found, extracting")
@ -103,40 +114,58 @@ def _preprocess_data(args):
# if ignoreSAsentences we only want those without SA in the name
# OR
# if not ignoreSAsentences we want all to be added
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences):
if 'train' in full_wav.lower():
if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
not ignoreSASentences
):
if "train" in full_wav.lower():
train_list_wavs.append(full_wav)
train_list_trans.append(trans)
train_list_size.append(wav_filesize)
elif 'test' in full_wav.lower():
elif "test" in full_wav.lower():
test_list_wavs.append(full_wav)
test_list_trans.append(trans)
test_list_size.append(wav_filesize)
else:
raise IOError
a = {'wav_filename': train_list_wavs,
'wav_filesize': train_list_size,
'transcript': train_list_trans
}
a = {
"wav_filename": train_list_wavs,
"wav_filesize": train_list_size,
"transcript": train_list_trans,
}
c = {'wav_filename': test_list_wavs,
'wav_filesize': test_list_size,
'transcript': test_list_trans
}
c = {
"wav_filename": test_list_wavs,
"wav_filesize": test_list_size,
"transcript": test_list_trans,
}
all = {'wav_filename': train_list_wavs + test_list_wavs,
'wav_filesize': train_list_size + test_list_size,
'transcript': train_list_trans + test_list_trans
}
all = {
"wav_filename": train_list_wavs + test_list_wavs,
"wav_filesize": train_list_size + test_list_size,
"transcript": train_list_trans + test_list_trans,
}
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
df_all = pd.DataFrame(
all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_train = pd.DataFrame(
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_test = pd.DataFrame(
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
)
df_all.to_csv(
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_train.to_csv(
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_test.to_csv(
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
)
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
if __name__ == "__main__":
_preprocess_data(sys.argv[1])

View File

@ -1,52 +1,53 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import csv
import os
import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
import csv
import unidecode
import zipfile
import sox
import subprocess
import progressbar
import zipfile
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from os import path
import progressbar
import sox
from util.downloader import maybe_download
import unidecode
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
get_importers_parser,
get_validate_label,
print_import_report,
)
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15
ARCHIVE_NAME = '2019-04-11_fr_FR'
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip'
ARCHIVE_NAME = "2019-04-11_fr_FR"
ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
ARCHIVE_URL = (
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
)
def _download_and_preprocess_data(target_dir, english_compatible=False):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
archive_path = maybe_download(
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
)
# Conditionally extract archive data
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible)
_maybe_convert_sets(
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
)
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print('No directory "%s" - extracting archive...' % extracted_path)
if not os.path.isdir(extracted_path):
os.mkdir(extracted_path)
@ -58,16 +59,20 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
orig_filename = sample['path']
orig_filename = sample["path"]
# Storing converted wav files next to the original ones - just with a different suffix
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
_maybe_convert_wav(orig_filename, wav_filename)
file_size = -1
frames = 0
if path.exists(wav_filename):
file_size = path.getsize(wav_filename)
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
label = sample['text']
if os.path.exists(wav_filename):
file_size = os.path.getsize(wav_filename)
frames = int(
subprocess.check_output(
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
)
)
label = sample["text"]
rows = []
@ -75,40 +80,41 @@ def one_sample(sample):
counter = get_counter()
if file_size == -1:
# Excluding samples that failed upon conversion
counter['failed'] += 1
counter["failed"] += 1
elif label is None:
# Excluding samples that failed on label validation
counter['invalid_label'] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
counter["invalid_label"] += 1
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript
counter['too_short'] += 1
elif frames/SAMPLE_RATE > MAX_SECS:
counter["too_short"] += 1
elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size
counter['too_long'] += 1
counter["too_long"] += 1
else:
# This one is good - keep it for the target CSV
rows.append((wav_filename, file_size, label))
counter['all'] += 1
counter['total_time'] += frames
counter["all"] += 1
counter["total_time"] += frames
return (counter, rows)
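To make the length guard above concrete, a small sketch with hypothetical values (assuming the ~10 ms feature step and the halving implied by the formula):

frames = 2 * SAMPLE_RATE                           # hypothetical 2 s clip
steps = int(frames / SAMPLE_RATE * 1000 / 10 / 2)  # -> 100 output steps
assert steps >= len("a short transcript")          # 18 chars: sample is kept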
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
extracted_dir = path.join(target_dir, extracted_data)
extracted_dir = os.path.join(target_dir, extracted_data)
# override existing CSV with normalized one
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
if os.path.isfile(target_csv_template):
return
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
path_to_original_csv = os.path.join(extracted_dir, "data.csv")
with open(path_to_original_csv) as csv_f:
data = [
d for d in csv.DictReader(csv_f, delimiter=',')
if float(d['duration']) <= MAX_SECS
d
for d in csv.DictReader(csv_f, delimiter=",")
if float(d["duration"]) <= MAX_SECS
]
for line in data:
line['path'] = os.path.join(extracted_dir, line['path'])
line["path"] = os.path.join(extracted_dir, line["path"])
num_samples = len(data)
rows = []
@ -125,9 +131,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
pool.close()
pool.join()
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -136,7 +142,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
test_writer.writeheader()
for i, item in enumerate(rows):
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
transcript = validate_label(
cleanup_transcript(
item[2], english_compatible=english_compatible
)
)
if not transcript:
continue
wav_filename = os.path.join(target_dir, extracted_data, item[0])
@ -147,45 +157,53 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
writer = dev_writer
else:
writer = train_writer
writer.writerow(dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
))
writer.writerow(
dict(
wav_filename=wav_filename,
wav_filesize=os.path.getsize(wav_filename),
transcript=transcript,
)
)
imported_samples = get_imported_samples(counter)
assert counter['all'] == num_samples
assert counter["all"] == num_samples
assert len(rows) == imported_samples
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
def _maybe_convert_wav(orig_filename, wav_filename):
if not path.exists(wav_filename):
if not os.path.exists(wav_filename):
transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE)
try:
transformer.build(orig_filename, wav_filename)
except sox.core.SoxError as ex:
print('SoX processing error', ex, orig_filename, wav_filename)
print("SoX processing error", ex, orig_filename, wav_filename)
PUNCTUATIONS_REG = re.compile(r"\-,;!?.()\[\]*…—]")
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}')
MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
def cleanup_transcript(text, english_compatible=False):
text = text.replace('’', "'").replace('\u00A0', ' ')
text = PUNCTUATIONS_REG.sub(' ', text)
text = MULTIPLE_SPACES_REG.sub(' ', text)
text = text.replace("", "'").replace("\u00A0", " ")
text = PUNCTUATIONS_REG.sub(" ", text)
text = MULTIPLE_SPACES_REG.sub(" ", text)
if english_compatible:
text = unidecode.unidecode(text)
return text.strip().lower()
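A minimal sketch of what cleanup_transcript produces (inputs hypothetical):

print(cleanup_transcript("Bonjour, le monde !"))                # -> "bonjour le monde"
print(cleanup_transcript("Déjà vu…", english_compatible=True))  # -> "deja vu"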
def handle_args():
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
parser.add_argument(dest='target_dir')
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diacritics and other non-ASCII chars.')
parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
parser.add_argument(dest="target_dir")
parser.add_argument(
"--english-compatible",
action="store_true",
dest="english_compatible",
help="Remove diactrics and other non-ascii chars.",
)
return parser.parse_args()

View File

@ -1,45 +1,40 @@
#!/usr/bin/env python
'''
"""
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
Use "python3 import_tuda.py -h" for help
'''
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import csv
import wave
import tarfile
"""
import argparse
import progressbar
import csv
import os
import tarfile
import unicodedata
import wave
import xml.etree.cElementTree as ET
from os import path
from collections import Counter
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
TUDA_VERSION = 'v2'
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE)
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE)
import progressbar
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet
TUDA_VERSION = "v2"
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
TUDA_PACKAGE
)
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
CHANNELS = 1
SAMPLE_WIDTH = 2
SAMPLE_RATE = 16000
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
def maybe_extract(archive):
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if path.isdir(extracted):
extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
if os.path.isdir(extracted):
print('Found directory "{}" - not extracting.'.format(extracted))
else:
print('Extracting "{}"...'.format(archive))
@ -52,86 +47,100 @@ def maybe_extract(archive):
def check_and_prepare_sentence(sentence):
sentence = sentence.lower().replace('co2', 'c o zwei')
sentence = sentence.lower().replace("co2", "c o zwei")
chars = []
for c in sentence:
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)):
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
if (
CLI_ARGS.normalize
and c not in "äöüß"
and (ALPHABET is None or not ALPHABET.has_char(c))
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c:
if ALPHABET is not None and not ALPHABET.has_char(c):
return None
chars.append(sc)
return validate_label(''.join(chars))
return validate_label("".join(chars))
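The normalize/encode/decode round-trip above folds accented characters to their ASCII base; a minimal illustration:

import unicodedata

# NFKD splits "é" into "e" plus a combining accent; ascii/ignore drops the accent.
# "ä" never reaches this path thanks to the "äöüß" guard above.
assert unicodedata.normalize("NFKD", "é").encode("ascii", "ignore").decode("ascii", "ignore") == "e"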
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
try:
with wave.open(wav_path, 'r') as src_wav_file:
with wave.open(wav_path, "r") as src_wav_file:
rate = src_wav_file.getframerate()
channels = src_wav_file.getnchannels()
sample_width = src_wav_file.getsampwidth()
milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
if rate != SAMPLE_RATE:
return False, 'wrong sample rate'
return False, "wrong sample rate"
if channels != CHANNELS:
return False, 'wrong number of channels'
return False, "wrong number of channels"
if sample_width != SAMPLE_WIDTH:
return False, 'wrong sample width'
return False, "wrong sample width"
if milliseconds / len(sentence) < 30:
return False, 'too short'
return False, "too short"
if milliseconds > CLI_ARGS.max_duration > 0:
return False, 'too long'
return False, "too long"
except wave.Error:
return False, 'invalid wav file'
return False, "invalid wav file"
except EOFError:
return False, 'premature EOF'
return True, 'OK'
return False, "premature EOF"
return True, "OK"
def write_csvs(extracted):
sample_counter = 0
reasons = Counter()
for sub_set in ['train', 'dev', 'test']:
set_path = path.join(extracted, sub_set)
for sub_set in ["train", "dev", "test"]:
set_path = os.path.join(extracted, sub_set)
set_files = os.listdir(set_path)
recordings = {}
for file in set_files:
if file.endswith('.xml'):
if file.endswith(".xml"):
recordings[file[:-4]] = []
for file in set_files:
if file.endswith('.wav') and '_' in file:
prefix = file.split('_')[0]
if file.endswith(".wav") and "_" in file:
prefix = file.split("_")[0]
if prefix in recordings:
recordings[prefix].append(file)
recordings = recordings.items()
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
csv_path = os.path.join(
CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
)
print('Writing "{}"...'.format(csv_path))
with open(csv_path, 'w') as csv_file:
with open(csv_path, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
writer.writeheader()
set_dir = path.join(extracted, sub_set)
set_dir = os.path.join(extracted, sub_set)
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
for prefix, wav_names in bar(recordings):
xml_path = path.join(set_dir, prefix + '.xml')
xml_path = os.path.join(set_dir, prefix + ".xml")
meta = ET.parse(xml_path).getroot()
sentence = list(meta.iter('cleaned_sentence'))[0].text
sentence = list(meta.iter("cleaned_sentence"))[0].text
sentence = check_and_prepare_sentence(sentence)
if sentence is None:
continue
for wav_name in wav_names:
sample_counter += 1
wav_path = path.join(set_path, wav_name)
wav_path = os.path.join(set_path, wav_name)
keep, reason = check_wav_file(wav_path, sentence)
if keep:
writer.writerow({
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
'wav_filesize': path.getsize(wav_path),
'transcript': sentence.lower()
})
writer.writerow(
{
"wav_filename": os.path.relpath(
wav_path, CLI_ARGS.base_dir
),
"wav_filesize": os.path.getsize(wav_path),
"transcript": sentence.lower(),
}
)
else:
reasons[reason] += 1
if len(reasons.keys()) > 0:
print('Excluded samples:')
print("Excluded samples:")
for reason, n in reasons.most_common():
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
@ -150,13 +159,29 @@ def download_and_prepare():
def handle_args():
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)')
parser.add_argument('base_dir', help='Directory containing all data')
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file')
parser.add_argument('--keep_archive', type=bool, default=True,
help='If downloaded archives should be kept')
parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
parser.add_argument("base_dir", help="Directory containing all data")
parser.add_argument(
"--max_duration",
type=int,
default=10000,
help="Maximum sample duration in milliseconds",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
parser.add_argument(
"--alphabet",
help="Exclude samples with characters not in provided alphabet file",
)
parser.add_argument(
"--keep_archive",
type=bool,
default=True,
help="If downloaded archives should be kept",
)
return parser.parse_args()

View File

@ -1,29 +1,22 @@
#!/usr/bin/env python
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
# Licensed under Open Data Commons Attribution License (ODC-By) v1.0.
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import random
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))
from util.importers import get_counter, get_imported_samples, print_import_report
import re
from multiprocessing import Pool
from zipfile import ZipFile
import librosa
import progressbar
from os import path
from multiprocessing import Pool
from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
get_counter,
get_imported_samples,
print_import_report,
)
SAMPLE_RATE = 16000
MAX_SECS = 10
@ -37,7 +30,7 @@ ARCHIVE_URL = (
def _download_and_preprocess_data(target_dir):
# Making path absolute
target_dir = path.abspath(target_dir)
target_dir = os.path.abspath(target_dir)
# Conditionally download data
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
# Conditionally extract common voice data
@ -48,8 +41,8 @@ def _download_and_preprocess_data(target_dir):
def _maybe_extract(target_dir, extracted_data, archive_path):
# If target_dir/extracted_data does not exist, extract archive in target_dir
extracted_path = path.join(target_dir, extracted_data)
if not path.exists(extracted_path):
extracted_path = os.path.join(target_dir, extracted_data)
if not os.path.exists(extracted_path):
print(f"No directory {extracted_path} - extracting archive...")
with ZipFile(archive_path, "r") as zipobj:
# Extract all the contents of zip file in current directory
@ -59,15 +52,17 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
def _maybe_convert_sets(target_dir, extracted_data):
extracted_dir = path.join(target_dir, extracted_data, "wav48")
txt_dir = path.join(target_dir, extracted_data, "txt")
extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
txt_dir = os.path.join(target_dir, extracted_data, "txt")
directory = os.path.expanduser(extracted_dir)
srtd = len(sorted(os.listdir(directory)))
all_samples = []
for target in sorted(os.listdir(directory)):
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
all_samples += _maybe_prepare_set(
path.join(extracted_dir, os.path.split(target)[-1])
)
num_samples = len(all_samples)
print(f"Converting wav files to {SAMPLE_RATE}hz...")
@ -81,6 +76,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
_write_csv(extracted_dir, txt_dir, target_dir)
def one_sample(sample):
if is_audio_file(sample):
y, sr = librosa.load(sample, sr=16000)
@ -103,6 +99,7 @@ def _maybe_prepare_set(target_csv):
samples = new_samples
return samples
def _write_csv(extracted_dir, txt_dir, target_dir):
print(f"Writing CSV file")
dset_abs_path = extracted_dir
@ -197,7 +194,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]
def is_audio_file(filepath):
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
return any(
os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
)
if __name__ == "__main__":

View File

@ -1,24 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import codecs
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import tarfile
import pandas
import re
import unicodedata
import tarfile
import threading
from multiprocessing.pool import ThreadPool
from six.moves import urllib
import unicodedata
import urllib
from glob import glob
from multiprocessing.pool import ThreadPool
from os import makedirs, path
import pandas
from bs4 import BeautifulSoup
from tensorflow.python.platform import gfile
from util.downloader import maybe_download
from deepspeech_training.util.downloader import maybe_download
"""The number of jobs to run in parallel"""
NUM_PARALLEL = 8
@ -26,8 +21,10 @@ NUM_PARALLEL = 8
"""Lambda function returns the filename of a path"""
filename_of = lambda x: path.split(x)[1]
class AtomicCounter(object):
"""A class that atomically increments a counter"""
def __init__(self, start_count=0):
"""Initialize the counter
:param start_count: the number to start counting at
@ -50,6 +47,7 @@ class AtomicCounter(object):
"""Returns the current value of the counter (not atomic)"""
return self.__count
def _parallel_downloader(voxforge_url, archive_dir, total, counter):
"""Generate a function to download a file based on given parameters
This works by currying the arguments given above into a closure
@ -61,6 +59,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
:param counter: an atomic counter to keep track of # of downloaded files
:return: a function that actually downloads a file given these params
"""
def download(d):
"""Binds voxforge_url, archive_dir, total, and counter into this scope
Downloads the given file
@ -68,12 +67,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
of the file to download and file is the name of the file to download
"""
(i, file) = d
download_url = voxforge_url + '/' + file
download_url = voxforge_url + "/" + file
c = counter.increment()
print('Downloading file {} ({}/{})...'.format(i+1, c, total))
print("Downloading file {} ({}/{})...".format(i + 1, c, total))
maybe_download(filename_of(download_url), archive_dir, download_url)
return download
def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter):
"""Generate a function to extract a tar file based on given parameters
This works by currying the arguments given above into a closure
@ -86,6 +87,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
:param counter: an atomic counter to keep track of # of extracted files
:return: a function that actually extracts a tar file given these params
"""
def extract(d):
"""Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
Extracts the given file
@ -95,58 +97,74 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
(i, archive) = d
if i < number_of_test:
dataset_dir = path.join(data_dir, "test")
elif i<number_of_test+number_of_dev:
elif i < number_of_test + number_of_dev:
dataset_dir = path.join(data_dir, "dev")
else:
dataset_dir = path.join(data_dir, "train")
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
if not gfile.Exists(
os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
):
c = counter.increment()
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
print("Extracting file {} ({}/{})...".format(i + 1, c, total))
tar = tarfile.open(archive)
tar.extractall(dataset_dir)
tar.close()
return extract
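The index-based split above works out as follows (archive count hypothetical):

number_of_files = 650                    # hypothetical
number_of_test = number_of_files // 100  # -> 6: archive indices 0-5 go to test
number_of_dev = number_of_files // 100   # -> 6: indices 6-11 go to dev
# the remaining 638 archives are extracted into train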
def _download_and_preprocess_data(data_dir):
# Conditionally download data to data_dir
if not path.isdir(data_dir):
makedirs(data_dir)
archive_dir = data_dir+"/archive"
archive_dir = data_dir + "/archive"
if not path.isdir(archive_dir):
makedirs(archive_dir)
print("Downloading Voxforge data set into {} if not already present...".format(archive_dir))
print(
"Downloading Voxforge data set into {} if not already present...".format(
archive_dir
)
)
voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit'
voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
html_page = urllib.request.urlopen(voxforge_url)
soup = BeautifulSoup(html_page, 'html.parser')
soup = BeautifulSoup(html_page, "html.parser")
# list all links
refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']]
refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]
# download files in parallel
print('{} files to download'.format(len(refs)))
downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter())
print("{} files to download".format(len(refs)))
downloader = _parallel_downloader(
voxforge_url, archive_dir, len(refs), AtomicCounter()
)
p = ThreadPool(NUM_PARALLEL)
p.map(downloader, enumerate(refs))
# Conditionally extract data to dataset_dir
if not path.isdir(path.join(data_dir,"test")):
makedirs(path.join(data_dir,"test"))
if not path.isdir(path.join(data_dir,"dev")):
makedirs(path.join(data_dir,"dev"))
if not path.isdir(path.join(data_dir,"train")):
makedirs(path.join(data_dir,"train"))
if not path.isdir(os.path.join(data_dir, "test")):
makedirs(os.path.join(data_dir, "test"))
if not path.isdir(os.path.join(data_dir, "dev")):
makedirs(os.path.join(data_dir, "dev"))
if not path.isdir(os.path.join(data_dir, "train")):
makedirs(os.path.join(data_dir, "train"))
tarfiles = glob(path.join(archive_dir, "*.tgz"))
tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
number_of_files = len(tarfiles)
number_of_test = number_of_files//100
number_of_dev = number_of_files//100
number_of_test = number_of_files // 100
number_of_dev = number_of_files // 100
# extract tars in parallel
print("Extracting Voxforge data set into {} if not already present...".format(data_dir))
extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter())
print(
"Extracting Voxforge data set into {} if not already present...".format(
data_dir
)
)
extracter = _parallel_extracter(
data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
)
p.map(extracter, enumerate(tarfiles))
# Generate data set
@ -156,42 +174,50 @@ def _download_and_preprocess_data(data_dir):
train_files = _generate_dataset(data_dir, "train")
# Write sets to disk as CSV files
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False)
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False)
train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
def _generate_dataset(data_dir, data_set):
extracted_dir = path.join(data_dir, data_set)
files = []
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")):
if path.isdir(path.join(promts_file[:-11],"wav")):
with codecs.open(promts_file, 'r', 'utf-8') as f:
for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
if path.isdir(os.path.join(promts_file[:-11], "wav")):
with codecs.open(promts_file, "r", "utf-8") as f:
for line in f:
id = line.split(' ')[0].split('/')[-1]
sentence = ' '.join(line.split(' ')[1:])
sentence = re.sub("[^a-z']"," ",sentence.strip().lower())
id = line.split(" ")[0].split("/")[-1]
sentence = " ".join(line.split(" ")[1:])
sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
transcript = ""
for token in sentence.split(" "):
word = token.strip()
if word!="" and word!=" ":
if word != "" and word != " ":
transcript += word + " "
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
.encode("ascii", "ignore") \
.decode("ascii", "ignore")
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav")
transcript = (
unicodedata.normalize("NFKD", transcript.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
if gfile.Exists(wav_file):
wav_filesize = path.getsize(wav_file)
# remove audios that are shorter than 0.5s or longer than 20s.
# remove audios that are too short for transcript.
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \
wav_filesize/len(transcript)>1400:
files.append((path.abspath(wav_file), wav_filesize, transcript))
if (
(wav_filesize / 32000) > 0.5
and (wav_filesize / 32000) < 20
and transcript != ""
and wav_filesize / len(transcript) > 1400
):
files.append(
(os.path.abspath(wav_file), wav_filesize, transcript)
)
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
return pandas.DataFrame(
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
)
if __name__=="__main__":
if __name__ == "__main__":
_download_and_preprocess_data(sys.argv[1])

View File

@ -1,15 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tfv1
import sys
import tensorflow.compat.v1 as tfv1
def main():
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
graph_def = tfv1.GraphDef()
graph_def.ParseFromString(fin.read())
print('\n'.join(sorted(set(n.op for n in graph_def.node))))
print("\n".join(sorted(set(n.op for n in graph_def.node))))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -3,19 +3,13 @@
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 build_sdb.py -h" for help
"""
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import random
import argparse
import random
import sys
from util.sample_collections import samples_from_file, LabeledSample
from util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file
def play_sample(samples, index):
@ -24,7 +18,7 @@ def play_sample(samples, index):
if CLI_ARGS.random:
index = random.randint(0, len(samples))
elif index >= len(samples):
print('No sample with index {}'.format(CLI_ARGS.start))
print("No sample with index {}".format(CLI_ARGS.start))
sys.exit(1)
sample = samples[index]
print('Sample "{}"'.format(sample.sample_id))
@ -50,13 +44,28 @@ def play_collection():
def handle_args():
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
'and DeepSpeech CSV files')
parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
parser.add_argument('--start', type=int, default=0,
help='Sample index to start at (negative numbers are relative to the end of the collection)')
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
parser = argparse.ArgumentParser(
description="Tool for playing samples from Sample Databases (SDB files) "
"and DeepSpeech CSV files"
)
parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
parser.add_argument(
"--start",
type=int,
default=0,
help="Sample index to start at (negative numbers are relative to the end of the collection)",
)
parser.add_argument(
"--number",
type=int,
default=-1,
help="Number of samples to play (-1 for endless)",
)
parser.add_argument(
"--random",
action="store_true",
help="If samples should be played in random order",
)
return parser.parse_args()
@ -70,5 +79,5 @@ if __name__ == "__main__":
try:
play_collection()
except KeyboardInterrupt:
print(' Stopped')
print(" Stopped")
sys.exit(0)

View File

@ -1,17 +1,11 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
import argparse
import shutil
import sys
from util.text import Alphabet, UTF8Alphabet
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet

View File

@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
.. code-block::
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/
$ python3 -m venv $HOME/tmp/deepspeech-train-venv/
Once this command completes successfully, the environment will be ready to be activated.
@ -46,7 +46,7 @@ Install the required dependencies using ``pip3``\ :
.. code-block:: bash
cd DeepSpeech
pip3 install -r requirements.txt
pip3 install -e .
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
@ -70,7 +70,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
.. code-block:: bash
pip3 uninstall tensorflow
pip3 install 'tensorflow-gpu==1.15.0'
pip3 install 'tensorflow-gpu==1.15.2'
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.

158
evaluate.py Executable file → Normal file
View File

@ -2,155 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.evaluate_tools import calculate_and_print_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
if __name__ == '__main__':
create_flags()
absl.app.run(main)
try:
from deepspeech_training import evaluate as ds_evaluate
except ImportError:
print('Training package is not installed. See training documentation.')
raise
ds_evaluate.run_script()

View File

@ -10,13 +10,12 @@ import csv
import os
import sys
from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model
from util.evaluate_tools import calculate_and_print_report
from util.flags import create_flags
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
from deepspeech_training.util.flags import create_flags
from functools import partial
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from six.moves import zip, range
r'''
This module should be self-contained:

View File

@ -2,19 +2,18 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import sys
import optuna
import absl.app
from ds_ctcdecoder import Scorer
import optuna
import sys
import tensorflow.compat.v1 as tfv1
from DeepSpeech import create_model
from evaluate import evaluate
from util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.evaluate_tools import wer_cer_batch
from deepspeech_training.evaluate import evaluate
from deepspeech_training.train import create_model
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error
from deepspeech_training.util.evaluate_tools import wer_cer_batch
from ds_ctcdecoder import Scorer
def character_based():

View File

@ -1,24 +0,0 @@
# Main training requirements
tensorflow == 1.15.2
numpy == 1.18.1
progressbar2
six
pyxdg
attrdict
absl-py
semver
opuslib == 2.0.0
# Requirements for building native_client files
setuptools
# Requirements for importers
sox
bs4
pandas
requests
librosa
soundfile
# Requirements for optimizer
optuna

60
setup.py Normal file
View File

@ -0,0 +1,60 @@
from pathlib import Path
from setuptools import find_packages, setup
def main():
version_file = Path(__file__).parent / 'VERSION'
with open(str(version_file)) as fin:
version = fin.read().strip()
setup(
name='deepspeech_training',
version=version,
description='Training code for Mozilla DeepSpeech',
url='https://github.com/mozilla/DeepSpeech',
author='Mozilla',
license='MPL-2.0',
# Classifiers help users find your project by categorizing it.
#
# For a list of valid classifiers, see https://pypi.org/classifiers/
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Topic :: Multimedia :: Sound/Audio :: Speech',
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
'Programming Language :: Python :: 3',
],
package_dir={'': 'training'},
packages=find_packages(where='training'),
python_requires='>=3.5, <4',
install_requires=[
'tensorflow == 1.15.2',
'numpy == 1.18.1',
'progressbar2',
'six',
'pyxdg',
'attrdict',
'absl-py',
'semver',
'opuslib == 2.0.0',
'optuna',
'sox',
'bs4',
'pandas',
'requests',
'librosa',
'soundfile',
],
# If there are data files included in your packages that need to be
# installed, specify them here.
package_data={
'deepspeech_training': [
'VERSION',
'GRAPH_VERSION',
],
},
)
if __name__ == '__main__':
main()

View File

@ -1,10 +1,29 @@
#!/usr/bin/env python3
import argparse
import os
import functools
import pandas
from deepspeech_training.util.helpers import secs_to_hours
from pathlib import Path
def read_csvs(csv_files):
# Relative paths are relative to CSV location
def absolutify(csv, path):
path = Path(path)
if path.is_absolute():
return str(path)
return str(csv.parent / path)
sets = []
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
sets.append(file)
# Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True)
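A minimal sketch of the relative-path rule implemented above (paths hypothetical):

from pathlib import Path

# a wav_filename of "clips/a.wav" inside /data/fr/train.csv resolves against
# the CSV's own directory, not against the current working directory:
str(Path("/data/fr/train.csv").parent / "clips/a.wav")  # -> "/data/fr/clips/a.wav"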
from util.helpers import secs_to_hours
from util.feeding import read_csvs
def main():
parser = argparse.ArgumentParser()
@ -14,20 +33,16 @@ def main():
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
csv_dataframe = read_csvs(in_files)
total_bytes = csv_dataframe['wav_filesize'].sum()
total_files = len(csv_dataframe.index)
total_files = len(csv_dataframe)
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
bytes_without_headers = total_bytes - 44 * total_files
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8)
print('total_bytes', total_bytes)
print('total_files', total_files)
print('bytes_without_headers', bytes_without_headers)
print('total_time', secs_to_hours(total_time))
print('Total bytes:', total_bytes)
print('Total files:', total_files)
print('Total time:', secs_to_hours(total_seconds))
if __name__ == '__main__':
main()
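For reference, the 44-byte-header arithmetic above with hypothetical numbers (10 s of mono 16-bit audio at 16 kHz):

wav_filesize = 44 + 10 * 16000 * 1 * (16 // 8)         # -> 320044 bytes
seconds = (wav_filesize - 44) / 16000 / 1 / (16 // 8)  # -> 10.0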

View File

@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
set -o pipefail
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail
which deepspeech

View File

@ -17,7 +17,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")

View File

@ -16,7 +16,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/

View File

@ -14,7 +14,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/

View File

@ -1,10 +1,14 @@
import unittest
from argparse import Namespace
from .importers import validate_label_eng, get_validate_label
from deepspeech_training.util.importers import validate_label_eng, get_validate_label
from pathlib import Path
def from_here(path):
here = Path(__file__)
return here.parent / path
class TestValidateLabelEng(unittest.TestCase):
def test_numbers(self):
label = validate_label_eng("this is a 1 2 3 test")
self.assertEqual(label, None)
@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
self.assertEqual(f('toto1234[{[{[]'), None)
def test_get_validate_label_missing(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
f = get_validate_label(args)
self.assertEqual(f, None)
def test_get_validate_label(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
f = get_validate_label(args)
l = f('toto')
self.assertEqual(l, 'toto')

View File

@ -1,7 +1,7 @@
import unittest
import os
from .text import Alphabet
from deepspeech_training.util.text import Alphabet
class TestAlphabetParsing(unittest.TestCase):

View File

@ -0,0 +1 @@
../../GRAPH_VERSION

View File

@ -0,0 +1 @@
../../VERSION

View File

View File

@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import json
import sys
from multiprocessing import cpu_count
import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
from .util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):
r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``.
"""
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [[] for _ in range(sp_tuple[2][0])]
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
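A minimal sketch of the regrouping done above (toy token ids; the alphabet is a hypothetical stand-in for Config.alphabet):

class _StubAlphabet:
    def decode(self, ids):
        return "".join(chr(97 + i) for i in ids)  # 0 -> "a", 1 -> "b", ...

sp_tuple = ([(0, 0), (0, 1), (1, 0)], [0, 1, 2], (2, 2))  # 2 transcripts in a batch
print(sparse_tuple_to_texts(sp_tuple, _StubAlphabet()))   # -> ['ab', 'c']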
def evaluate(test_csvs, create_model):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x,
batch_size=FLAGS.test_batch_size,
seq_length=batch_x_len,
dropout=no_dropout)
# Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step()
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with tfv1.Session(config=Config.session_config) as session:
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_test(init_op, dataset):
wav_filenames = []
losses = []
predictions = []
ground_truths = []
bar = create_progressbar(prefix='Test epoch | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Test epoch...')
step_count = 0
# Initialize iterator to the appropriate dataset
session.run(init_op)
# First pass, compute losses and transposed logits for decoding
while True:
try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
losses.extend(batch_loss)
step_count += 1
bar.update(step_count)
bar.finish()
# Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
return test_samples
samples = []
for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv))
samples.extend(run_test(init_op, dataset=csv))
return samples
def main(_):
initialize_globals()
if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()

View File

@ -0,0 +1,936 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import os
import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
check_ctcdecoder_version()
# Graph Creation
# ==============
def variable_on_cpu(name, shape, initializer):
r"""
Next we concern ourselves with graph creation.
However, before we do so we must introduce a utility function ``variable_on_cpu()``
used to create a variable in CPU memory.
"""
# Use the /cpu:0 device for scoped operations
with tf.device(Config.cpu_device):
# Create or get apropos variable
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
return var
def create_overlapping_windows(batch_x):
batch_size = tf.shape(input=batch_x)[0]
window_width = 2 * Config.n_context + 1
num_channels = Config.n_input
# Create a constant convolution filter using an identity matrix, so that the
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
return batch_x
def dense(name, x, units, dropout_rate=None, relu=True):
with tfv1.variable_scope(name):
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
if relu:
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
if dropout_rate is not None:
output = tf.nn.dropout(output, rate=dropout_rate)
return output
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
output, output_state = fw_cell(inputs=x,
dtype=tf.float32,
sequence_length=seq_length,
initial_state=previous_state)
return output, output_state
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
# the object it creates the variables, and then you just call it several times
# to enable variable re-use. Because all of our code is structured in an
# old-school TensorFlow style where you can just call tf.get_variable again with
# reuse=True to reuse variables, we can't easily make use of the object-oriented
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
# emulating a static function variable.
if not rnn_impl_cudnn_rnn.cell:
# Forward direction cell:
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
num_units=Config.n_cell_dim,
input_mode='linear_input',
direction='unidirectional',
dtype=tf.float32)
rnn_impl_cudnn_rnn.cell = fw_cell
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
sequence_lengths=seq_length)
return output, output_state
rnn_impl_cudnn_rnn.cell = None
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
# Forward direction cell:
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
forget_bias=0,
reuse=reuse,
name='cudnn_compatible_lstm_cell')
# Split rank N tensor into list of rank N-1 tensors
x = [x[l] for l in range(x.shape[0])]
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
inputs=x,
sequence_length=seq_length,
initial_state=previous_state,
dtype=tf.float32,
scope='cell_0')
output = tf.concat(output, 0)
return output, output_state
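# Note: the three rnn_impl_* variants above share the signature
# (x, seq_length, previous_state, reuse) -> (output, output_state), so
# create_model can swap backends: LSTMBlockFusedCell by default, CudnnLSTM
# when --train_cudnn is set, and static_rnn for TFLite export.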
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
layers = {}
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
if not batch_size:
batch_size = tf.shape(input=batch_x)[0]
# Create overlapping feature windows if needed
if overlap:
batch_x = create_overlapping_windows(batch_x)
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
# Permute n_steps and batch_size
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
# Reshape to prepare input for first layer
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
layers['input_reshaped'] = batch_x
# The next three blocks will pass `batch_x` through three hidden layers with
# clipped RELU activation and dropout.
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
# `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
# Run through parametrized RNN implementation, as we use different RNNs
# for training and inference
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
output = tf.reshape(output, [-1, Config.n_cell_dim])
layers['rnn_output'] = output
layers['rnn_output_state'] = output_state
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
# Note that this differs from the input in that it is time-major.
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
layers['raw_logits'] = layer_6
# Output shape: [n_steps, batch_size, n_hidden_6]
return layer_6, layers
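# For reference, the tensor shapes through create_model are:
#   windowed input:  [batch_size, n_steps, 2*n_context + 1, n_input]
#   dense layers:    [n_steps * batch_size, n_hidden_{1,2,3}]
#   RNN output:      [n_steps, batch_size, n_cell_dim]
#   raw logits:      [n_steps, batch_size, n_hidden_6]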
# Accuracy and Loss
# =================
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
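# For reference, the ctc_loss call below takes the transcripts as a
# tf.SparseTensor of character indices, time-major logits of shape
# [max_time, batch_size, n_hidden_6], and per-example feature lengths.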
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine computes the CTC loss for a mini-batch.
It returns the average loss across the batch together with the names of any
files that produced a non-finite loss.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
# Calculate the logits of the batch
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
# Check if any files lead to a non-finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
# Calculate the average loss across the batch
avg_loss = tf.reduce_mean(input_tensor=total_loss)
# Finally we return the average loss and any non-finite files
return avg_loss, non_finite_files
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
beta1=FLAGS.beta1,
beta2=FLAGS.beta2,
epsilon=FLAGS.epsilon)
return optimizer
# Towers
# ======
# In order to properly make use of multiple GPUs, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPUs.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
# is a means to isolate the operations within a tower.
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
# on which all operations within the tower execute.
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
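# A minimal sketch of the pattern implemented below (illustrative only):
#
#   for i, device in enumerate(Config.available_devices):
#       with tf.device(device):
#           with tf.name_scope('tower_%d' % i):
#               ...  # inference and gradient ops of tower i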
def get_tower_results(iterator, optimizer, dropout_rates):
r'''
With this preliminary step out of the way, we can, for each GPU, introduce a
tower for whose batch we calculate and return the optimization gradients
and the average loss across towers.
'''
# To calculate the mean of the losses
tower_avg_losses = []
# Tower gradients to return
tower_gradients = []
# Aggregate any non-finite files in the batches
tower_non_finite_files = []
with tfv1.variable_scope(tfv1.get_variable_scope()):
# Loop over available_devices
for i in range(len(Config.available_devices)):
# Execute operations of tower i on device i
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)
# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
# Retain tower's gradients
tower_gradients.append(gradients)
tower_non_finite_files.append(non_finite_files)
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
# Return gradients and the average loss
return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
r'''
A routine for computing the average of the gradients obtained from the GPUs
for each variable.
Note also that this code acts as a synchronization point as it requires all
GPUs to be finished with their mini-batch before it can run to completion.
'''
# List of average gradients to return to the caller
average_grads = []
# Run this on cpu_device to conserve GPU memory
with tf.device(Config.cpu_device):
# Loop over gradient/variable pairs from all towers
for grad_and_vars in zip(*tower_gradients):
# Introduce grads to store the gradients for the current variable
grads = []
# Loop over the gradients for the current variable
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension
grad = tf.concat(grads, 0)
grad = tf.reduce_mean(input_tensor=grad, axis=0)
# Create a gradient/variable tuple for the current variable with its average gradient
grad_and_var = (grad, grad_and_vars[0][1])
# Add the current tuple to average_grads
average_grads.append(grad_and_var)
# Return result to caller
return average_grads
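# With N towers this stacks each variable's per-tower gradients along a new
# leading axis of size N and averages over it, yielding one gradient per variable.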
# Logging
# =======
def log_variable(variable, gradient=None):
r'''
We introduce a function for logging a tensor variable's current state.
It logs scalar values for the mean, standard deviation, minimum and maximum.
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
'''
name = variable.name.replace(':', '_')
mean = tf.reduce_mean(input_tensor=variable)
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
tfv1.summary.histogram(name=name, values=variable)
if gradient is not None:
if isinstance(gradient, tf.IndexedSlices):
grad_values = gradient.values
else:
grad_values = gradient
if grad_values is not None:
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
def log_grads_and_vars(grads_and_vars):
r'''
Let's also introduce a helper function for logging collections of gradient/variable tuples.
'''
for gradient, variable in grads_and_vars:
log_variable(variable, gradient=gradient)
def train():
do_cache_dataset = True
# pylint: disable=too-many-boolean-expressions
if (FLAGS.data_aug_features_multiplicative > 0 or
FLAGS.data_aug_features_additive > 0 or
FLAGS.augmentation_spec_dropout_keeprate < 1 or
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
do_cache_dataset = False
exception_box = ExceptionBox()
# Create training and validation datasets
train_set = create_dataset(FLAGS.train_files.split(','),
batch_size=FLAGS.train_batch_size,
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
output_classes=tfv1.data.get_output_classes(train_set))
# Make initialization ops for switching between the two sets
train_init_op = iterator.make_initializer(train_set)
if FLAGS.dev_files:
dev_sources = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([source],
batch_size=FLAGS.dev_batch_size,
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
# Dropout
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
dropout_feed_dict = {
dropout_rates[0]: FLAGS.dropout_rate,
dropout_rates[1]: FLAGS.dropout_rate2,
dropout_rates[2]: FLAGS.dropout_rate3,
dropout_rates[3]: FLAGS.dropout_rate4,
dropout_rates[4]: FLAGS.dropout_rate5,
dropout_rates[5]: FLAGS.dropout_rate6,
}
no_dropout_feed_dict = {
rate: 0. for rate in dropout_rates
}
# Building the graph
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
optimizer = create_optimizer(learning_rate_var)
# Enable mixed precision training
if FLAGS.automatic_mixed_precision:
log_info('Enabling automatic mixed precision training.')
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
# Average tower gradients across GPUs
avg_tower_gradients = average_gradients(gradients)
log_grads_and_vars(avg_tower_gradients)
# global_step is automagically incremented by the optimizer
global_step = tfv1.train.get_or_create_global_step()
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
# Summaries
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
}
# Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string())
with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
# Load checkpoint or initialize variables
if FLAGS.load == 'auto':
method_order = ['best', 'last', 'init']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train'
train_op = apply_gradient_op if is_train else []
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
total_loss = 0.0
step_count = 0
step_summary_writer = step_summary_writers.get(set_name)
checkpoint_time = time.time()
# Setup progress bar
class LossWidget(progressbar.widgets.FormatLabel):
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
widgets = [' | ', progressbar.widgets.Timer(),
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)
# Batch loop
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
log_info("Ignoring sparse warp error: {}".format(err))
continue
raise
except tf.errors.OutOfRangeError:
exception_box.raise_if_set()
break
if problem_files.size > 0:
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
log_error('The following files caused an infinite (or NaN) '
'loss: {}'.format(','.join(problem_files)))
total_loss += batch_loss
step_count += 1
pbar.update(step_count)
step_summary_writer.add_summary(step_summary, current_step)
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
checkpoint_time = time.time()
pbar.finish()
mean_loss = total_loss / step_count if step_count > 0 else 0.0
return mean_loss, step_count
log_info('STARTING Optimization')
train_start_time = datetime.utcnow()
best_dev_loss = float('inf')
dev_losses = []
epochs_without_improvement = 0
try:
for epoch in range(FLAGS.epochs):
# Training
log_progress('Training epoch %d...' % epoch)
train_loss, _ = run_set('train', epoch, train_init_op)
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
if FLAGS.dev_files:
# Validation
dev_loss = 0.0
total_steps = 0
for source, init_op in zip(dev_sources, dev_init_ops):
log_progress('Validating epoch %d on %s...' % (epoch, source))
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
dev_loss += set_loss * steps
total_steps += steps
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
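# Weight each source's mean loss by its step count so that dev_loss is the
# mean over all validation batches rather than the mean over sources.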
dev_loss = dev_loss / total_steps
dev_losses.append(dev_loss)
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau;
# the improvement has to be greater than FLAGS.es_min_delta
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
epochs_without_improvement += 1
else:
epochs_without_improvement = 0
# Save new best model
if dev_loss < best_dev_loss:
best_dev_loss = dev_loss
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
# Early stopping
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
log_info('Early stop triggered as the loss did not improve in the last {} epochs'.format(
epochs_without_improvement))
break
# Reduce learning rate on plateau
if (FLAGS.reduce_lr_on_plateau and
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
# If the learning rate was reduced and there is still no improvement,
# wait FLAGS.plateau_epochs before the learning rate is reduced again
session.run(reduce_learning_rate_op)
current_learning_rate = learning_rate_var.eval()
log_info('Encountered a plateau, reducing learning rate to {}'.format(
current_learning_rate))
except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
log_debug('Session closed.')
def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
with open(FLAGS.test_output_file, 'w') as fout:
    json.dump(samples, fout, default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in DS_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed.
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
if batch_size is None:
# no state management since n_steps is expected to be dynamic too (see below)
previous_state = None
else:
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
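# The C/H state placeholders enable streaming inference: native clients feed
# the new_state_c/new_state_h outputs (defined below) back in as the next
# chunk's previous state.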
# One rate per layer
no_dropout = [None] * 6
if tflite:
rnn_impl = rnn_impl_static_rnn
else:
rnn_impl = rnn_impl_lstmblockfusedcell
logits, layers = create_model(batch_x=input_tensor,
batch_size=batch_size,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
rnn_impl=rnn_impl)
# TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension
if tflite:
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits, name='logits')
if batch_size is None:
if tflite:
raise NotImplementedError('dynamic batch_size does not support tflite or streaming')
if n_steps > 0:
raise NotImplementedError('dynamic batch_size expects n_steps to be dynamic too')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
},
{
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if not FLAGS.export_tflite:
inputs['input_lengths'] = seq_length
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
def file_relative_read(fname):
with open(os.path.join(os.path.dirname(__file__), fname)) as fin:
    return fin.read()
def export():
r'''
Restores the trained variables into a simpler graph that will be exported for serving.
'''
log_info('Exporting the model...')
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
output_names = output_names_tensors + output_names_ops
with tf.Session() as session:
# Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
output_filename = FLAGS.export_file_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(),
output_node_names=output_names)
frozen_graph = tfv1.graph_util.extract_sub_graph(
graph_def=frozen_graph,
dest_nodes=output_names)
if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model)
log_info('Models exported at %s' % (FLAGS.export_dir))
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
FLAGS.export_author_id,
FLAGS.export_model_name,
FLAGS.export_model_version))
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
with open(metadata_fname, 'w') as f:
f.write('---\n')
f.write('author: {}\n'.format(FLAGS.export_author_id))
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
f.write('license: {}\n'.format(FLAGS.export_license))
f.write('language: {}\n'.format(FLAGS.export_language))
f.write('runtime: {}\n'.format(model_runtime))
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
f.write('---\n')
f.write('{}\n'.format(FLAGS.export_description))
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing, make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
zip_filename = os.path.dirname(export_dir)
shutil.copy(FLAGS.scorer_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive))
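# Example (hypothetical paths): running with --export_dir path/to/export/eng/
# and --export_zip produces path/to/export/eng.zip, containing the exported
# model alongside the copied scorer.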
def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Restore variables from training checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# Evaluate
features = create_overlapping_windows(features).eval(session=session)
features_len = features_len.eval(session=session)
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n)
# Print highest probability result
print(decoded[0][1])
def main(_):
initialize_globals()
if FLAGS.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(FLAGS.random_seed)
train()
if FLAGS.test_files:
tfv1.reset_default_graph()
test()
if FLAGS.export_dir and not FLAGS.export_zip:
tfv1.reset_default_graph()
export()
if FLAGS.export_zip:
tfv1.reset_default_graph()
FLAGS.export_tflite = True
if os.listdir(FLAGS.export_dir):
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
sys.exit(1)
export()
package_zip()
if FLAGS.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
def run_script():
create_flags()
absl.app.run(main)
if __name__ == '__main__':
run_script()

View File

@ -5,7 +5,7 @@ import tempfile
import collections
import numpy as np
from util.helpers import LimitingPool
from .helpers import LimitingPool
DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1

View File

@ -2,8 +2,8 @@ import sys
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from util.flags import FLAGS
from util.logging import log_info, log_error, log_warn
from .flags import FLAGS
from .logging import log_info, log_error, log_warn
def _load_checkpoint(session, checkpoint_path):

View File

@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from util.flags import FLAGS
from util.gpu import get_available_gpus
from util.logging import log_error
from util.text import Alphabet, UTF8Alphabet
from util.helpers import parse_file_size
from .flags import FLAGS
from .gpu import get_available_gpus
from .logging import log_error
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
class ConfigSingleton:
_config = None

View File

@ -7,8 +7,8 @@ import numpy as np
from attrdict import AttrDict
from util.flags import FLAGS
from util.text import levenshtein
from .flags import FLAGS
from .text import levenshtein
def pmap(fun, iterable):

View File

@ -8,13 +8,13 @@ import tensorflow as tf
from tensorflow.python.ops import gen_audio_ops as contrib_audio
from util.config import Config
from util.text import text_to_char_array
from util.flags import FLAGS
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
from util.sample_collections import samples_from_files
from util.helpers import remember_exception, MEGABYTE
from .config import Config
from .text import text_to_char_array
from .flags import FLAGS
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
from .sample_collections import samples_from_files
from .helpers import remember_exception, MEGABYTE
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):

View File

@ -4,7 +4,7 @@ import os
import re
import sys
from util.helpers import secs_to_hours
from .helpers import secs_to_hours
from collections import Counter
def get_counter():

View File

@ -3,7 +3,7 @@ from __future__ import print_function
import progressbar
import sys
from util.flags import FLAGS
from .flags import FLAGS
# Logging functions

View File

@ -5,8 +5,8 @@ import json
from pathlib import Path
from functools import partial
from util.helpers import MEGABYTE, GIGABYTE, Interleaved
from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
from .helpers import MEGABYTE, GIGABYTE, Interleaved
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
BIG_ENDIAN = 'big'
INT_SIZE = 4

View File

@ -1,6 +1,7 @@
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from util.sparse_image_warp import sparse_image_warp
from .sparse_image_warp import sparse_image_warp
def augment_freq_time_mask(spectrogram,
frequency_masking_para=30,

View File

@ -0,0 +1,167 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import argparse
import errno
import gzip
import os
import platform
import six.moves.urllib as urllib
import stat
import subprocess
import sys
from pkg_resources import parse_version
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert artifact_name
assert branch_name is not None
assert branch_name
return TASKCLUSTER_SCHEME % {'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
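# For example (values are illustrative), get_tc_url('cpu', 'native_client.tar.xz', 'v0.7.0')
# fills the %(arch_string)s, %(artifact_name)s and %(branch_name)s placeholders
# of the active scheme.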
def maybe_download_tc(target_dir, tc_url, progress=True):
def report_progress(count, block_size, total_size):
percent = (count * block_size * 100) // total_size
sys.stdout.write("\rDownloading: %d%%" % percent)
sys.stdout.flush()
if percent >= 100:
print('\n')
assert target_dir is not None
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
is_gzip = False
if not os.path.isfile(target_file):
print('Downloading %s ...' % tc_url)
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
is_gzip = headers.get('Content-Encoding') == 'gzip'
else:
print('File already exists: %s' % target_file)
if is_gzip:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)
frw.truncate()
return target_file
def maybe_download_tc_bin(**kwargs):
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
final_stat = os.stat(final_file)
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
parser.add_argument('--arch', required=False,
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
parser.add_argument('--artifact', required=False,
default='native_client.tar.xz',
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
parser.add_argument('--source', required=False, default=None,
help='Name of the TaskCluster scheme to use.')
parser.add_argument('--branch', required=False,
help='Branch name to use. Defaulting to current content of VERSION file.')
parser.add_argument('--decoder', action='store_true',
help='Get URL to ds_ctcdecoder Python package.')
args = parser.parse_args()
if not args.target and not args.decoder:
print('Pass either --target or --decoder.')
sys.exit(1)
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
is_ucs2 = sys.maxunicode < 0x10ffff
if not args.arch:
if is_arm:
args.arch = 'arm64' if is_64bit else 'arm'
elif is_mac:
args.arch = 'osx'
else:
args.arch = 'cpu'
if not args.branch:
version_string = read('../VERSION').strip()
ds_version = parse_version(version_string)
args.branch = "v{}".format(version_string)
else:
ds_version = parse_version(args.branch)
if args.decoder:
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(map(str, sys.version_info[0:2]))
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch
)
ctc_arch = args.arch + '-ctc'
print(get_tc_url(ctc_arch, artifact, args.branch))
sys.exit(0)
if args.source is not None:
if args.source in DEFAULT_SCHEMES:
global TASKCLUSTER_SCHEME
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
else:
print('No such scheme: %s' % args.source)
sys.exit(1)
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
if args.artifact == "convert_graphdef_memmapped_format":
convert_graph_file = os.path.join(args.target, args.artifact)
final_stat = os.stat(convert_graph_file)
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
if __name__ == '__main__':
main()

View File

@ -12,13 +12,13 @@ tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger('sox').setLevel(logging.ERROR)
from multiprocessing import Process, cpu_count
from deepspeech_training.util.audio import AudioFile
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.feeding import split_audio_file
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error, log_info, log_progress, create_progressbar
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import Config, initialize_globals
from util.audio import AudioFile
from util.feeding import split_audio_file
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_info, log_progress, create_progressbar
from multiprocessing import Process, cpu_count
def fail(message, code=1):
@ -27,8 +27,8 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path):
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
from util.checkpoints import load_or_init_graph
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
from deepspeech_training.util.checkpoints import load_or_init_graph
initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try:

View File

@ -1,159 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import absolute_import
import os
import subprocess
import csv
from threading import Thread
from time import time
from scipy.interpolate import spline
from six.moves import range
# Do this to be able to use without X
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
class GPUUsage(Thread):
def __init__(self, csvfile=None):
super(GPUUsage, self).__init__()
self._cmd = [ 'nvidia-smi', 'dmon', '-d', '1', '-s', 'pucvmet' ]
self._names = []
self._units = []
self._process = None
self._csv_output = csvfile or os.environ.get('ds_gpu_usage_csv', self.make_basename(prefix='ds-gpu-usage', extension='csv'))
def get_git_desc(self):
return subprocess.check_output(['git', 'describe', '--always', '--abbrev']).strip()
def make_basename(self, prefix, extension):
# Let us assume that this code is executed in the current git clone
return '%s.%s.%s.%s' % (prefix, self.get_git_desc(), int(time()), extension)
def stop(self):
if not self._process:
print("Trying to stop nvidia-smi but no more process, please fix.")
return
print("Ending nvidia-smi monitoring: PID", self._process.pid)
self._process.terminate()
print("Ended nvidia-smi monitoring ...")
def run(self):
print("Starting nvidia-smi monitoring")
# If the system has no CUDA setup, then this will fail.
try:
self._process = subprocess.Popen(self._cmd, stdout=subprocess.PIPE)
except OSError as ex:
print("Unable to start monitoring, check your environment:", ex)
return
writer = None
with open(self._csv_output, 'w') as f:
for line in iter(self._process.stdout.readline, ''):
d = self.ingest(line)
if line.startswith('# '):
if len(self._names) == 0:
self._names = d
writer = csv.DictWriter(f, delimiter=str(','), quotechar=str('"'), fieldnames=d)
writer.writeheader()
continue
if len(self._units) == 0:
self._units = d
continue
else:
assert len(self._names) == len(self._units)
assert len(d) == len(self._names)
assert len(d) > 1
writer.writerow(self.merge_line(d))
f.flush()
def ingest(self, line):
return map(lambda x: x.replace('-', '0'), filter(lambda x: len(x) > 0, map(lambda x: x.strip(), line.split(' ')[1:])))
def merge_line(self, line):
return dict(zip(self._names, line))
class GPUUsageChart():
def __init__(self, source, basename=None):
self._rows = [ 'pwr', 'temp', 'sm', 'mem']
self._titles = {
'pwr': "Power (W)",
'temp': "Temperature (°C)",
'sm': "Streaming Multiprocessors (%)",
'mem': "Memory (%)"
}
self._data = { }.fromkeys(self._rows)
self._csv = source
self._basename = basename or os.environ.get('ds_gpu_usage_charts', 'gpu_usage_%%s_%d.png' % int(time.time()))
# This should make sure we start from anything clean.
plt.close("all")
try:
self.read()
for plot in self._rows:
self.produce_plot(plot)
except IOError as ex:
print("Unable to read", ex)
def append_data(self, row):
for bucket, value in row.iteritems():
if not bucket in self._rows:
continue
if not self._data[bucket]:
self._data[bucket] = {}
gpu = int(row['gpu'])
if not self._data[bucket].has_key(gpu):
self._data[bucket][gpu] = [ value ]
else:
self._data[bucket][gpu] += [ value ]
def read(self):
print("Reading data from", self._csv)
with open(self._csv, 'r') as f:
for r in csv.DictReader(f):
self.append_data(r)
def produce_plot(self, key, with_spline=True):
png = self._basename % (key, )
print("Producing plot for", key, "as", png)
fig, axis = plt.subplots()
data = self._data[key]
if data is None:
print("Data was empty, aborting")
return
x = list(range(len(data[0])))
if with_spline:
x = map(lambda x: float(x), x)
x_sm = np.array(x)
x_smooth = np.linspace(x_sm.min(), x_sm.max(), 300)
for gpu, y in data.iteritems():
if with_spline:
y = map(lambda x: float(x), y)
y_sm = np.array(y)
y_smooth = spline(x, y, x_smooth, order=1)
axis.plot(x_smooth, y_smooth, label='GPU %d' % (gpu))
else:
axis.plot(x, y, label='GPU %d' % (gpu))
axis.legend(loc="upper right", frameon=False)
axis.set_xlabel("Time (s)")
axis.set_ylabel("%s" % self._titles[key])
fig.set_size_inches(24, 18)
plt.title("GPU Usage: %s" % self._titles[key])
plt.savefig(png, dpi=100)
plt.close(fig)

View File

@ -1,168 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import argparse
import platform
import subprocess
import sys
import os
import errno
import stat
import gzip
import six.moves.urllib as urllib
from pkg_resources import parse_version
DEFAULT_SCHEMES = {
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
}
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert artifact_name
assert branch_name is not None
assert branch_name
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
def maybe_download_tc(target_dir, tc_url, progress=True):
def report_progress(count, block_size, total_size):
percent = (count * block_size * 100) // total_size
sys.stdout.write("\rDownloading: %d%%" % percent)
sys.stdout.flush()
if percent >= 100:
print('\n')
assert target_dir is not None
target_dir = os.path.abspath(target_dir)
try:
os.makedirs(target_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise e
assert os.path.isdir(os.path.dirname(target_dir))
tc_filename = os.path.basename(tc_url)
target_file = os.path.join(target_dir, tc_filename)
is_gzip = False
if not os.path.isfile(target_file):
print('Downloading %s ...' % tc_url)
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
is_gzip = headers.get('Content-Encoding') == 'gzip'
else:
print('File already exists: %s' % target_file)
if is_gzip:
with open(target_file, "r+b") as frw:
decompressed = gzip.decompress(frw.read())
frw.seek(0)
frw.write(decompressed)
frw.truncate()
return target_file
def maybe_download_tc_bin(**kwargs):
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
final_stat = os.stat(final_file)
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
parser.add_argument('--arch', required=False,
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
parser.add_argument('--artifact', required=False,
default='native_client.tar.xz',
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
parser.add_argument('--source', required=False, default=None,
help='Name of the TaskCluster scheme to use.')
parser.add_argument('--branch', required=False,
help='Branch name to use. Defaulting to current content of VERSION file.')
parser.add_argument('--decoder', action='store_true',
help='Get URL to ds_ctcdecoder Python package.')
args = parser.parse_args()
if not args.target and not args.decoder:
print('Pass either --target or --decoder.')
exit(1)
is_arm = 'arm' in platform.machine()
is_mac = 'darwin' in sys.platform
is_64bit = sys.maxsize > (2**31 - 1)
is_ucs2 = sys.maxunicode < 0x10ffff
if not args.arch:
if is_arm:
args.arch = 'arm64' if is_64bit else 'arm'
elif is_mac:
args.arch = 'osx'
else:
args.arch = 'cpu'
if not args.branch:
version_string = read('../VERSION').strip()
ds_version = parse_version(version_string)
args.branch = "v{}".format(version_string)
else:
ds_version = parse_version(args.branch)
if args.decoder:
plat = platform.system().lower()
arch = platform.machine()
if plat == 'linux' and arch == 'x86_64':
plat = 'manylinux1'
if plat == 'darwin':
plat = 'macosx_10_10'
m_or_mu = 'mu' if is_ucs2 else 'm'
pyver = ''.join(map(str, sys.version_info[0:2]))
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
ds_version=ds_version,
pyver=pyver,
m_or_mu=m_or_mu,
platform=plat,
arch=arch
)
ctc_arch = args.arch + '-ctc'
print(get_tc_url(ctc_arch, artifact, args.branch))
exit(0)
if args.source is not None:
if args.source in DEFAULT_SCHEMES:
global TASKCLUSTER_SCHEME
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
else:
print('No such scheme: %s' % args.source)
exit(1)
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
if args.artifact == "convert_graphdef_memmapped_format":
convert_graph_file = os.path.join(args.target, args.artifact)
final_stat = os.stat(convert_graph_file)
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
from __future__ import absolute_import, division, print_function
if __name__ == '__main__':
main()
try:
from deepspeech_training.util import taskcluster as dsu_taskcluster
except ImportError:
print('Training package is not installed. See training documentation.')
raise
dsu_taskcluster.main()