Merge pull request #2856 from reuben/training-install
Package training code to avoid sys.path hacks
This commit is contained in:
commit
83d22e591b
4
.isort.cfg
Normal file
4
.isort.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[settings]
|
||||
line_length=80
|
||||
multi_line_output=3
|
||||
default_section=FIRSTPARTY
|
@ -9,7 +9,7 @@ python:
|
||||
|
||||
jobs:
|
||||
include:
|
||||
- stage: cardboard linter
|
||||
- name: cardboard linter
|
||||
install:
|
||||
- pip install --upgrade cardboardlint pylint
|
||||
script:
|
||||
@ -17,9 +17,10 @@ jobs:
|
||||
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
||||
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
|
||||
fi
|
||||
- stage: python unit tests
|
||||
- name: python unit tests
|
||||
install:
|
||||
- pip install --upgrade -r requirements_tests.txt
|
||||
- pip install --upgrade -r requirements_tests.txt;
|
||||
pip install --upgrade .
|
||||
script:
|
||||
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
||||
python -m unittest;
|
||||
|
937
DeepSpeech.py
937
DeepSpeech.py
@ -2,934 +2,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||
|
||||
import absl.app
|
||||
import json
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import time
|
||||
|
||||
tfv1.logging.set_verbosity({
|
||||
'0': tfv1.logging.DEBUG,
|
||||
'1': tfv1.logging.INFO,
|
||||
'2': tfv1.logging.WARN,
|
||||
'3': tfv1.logging.ERROR
|
||||
}.get(DESIRED_LOG_LEVEL))
|
||||
|
||||
from datetime import datetime
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from util.config import Config, initialize_globals
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
# Graph Creation
|
||||
# ==============
|
||||
|
||||
def variable_on_cpu(name, shape, initializer):
|
||||
r"""
|
||||
Next we concern ourselves with graph creation.
|
||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||
used to create a variable in CPU memory.
|
||||
"""
|
||||
# Use the /cpu:0 device for scoped operations
|
||||
with tf.device(Config.cpu_device):
|
||||
# Create or get apropos variable
|
||||
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||
return var
|
||||
|
||||
|
||||
def create_overlapping_windows(batch_x):
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
window_width = 2 * Config.n_context + 1
|
||||
num_channels = Config.n_input
|
||||
|
||||
# Create a constant convolution filter using an identity matrix, so that the
|
||||
# convolution returns patches of the input tensor as is, and we can create
|
||||
# overlapping windows over the MFCCs.
|
||||
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
||||
|
||||
# Create overlapping windows
|
||||
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
||||
|
||||
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||
|
||||
return batch_x
|
||||
|
||||
|
||||
def dense(name, x, units, dropout_rate=None, relu=True):
|
||||
with tfv1.variable_scope(name):
|
||||
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||
|
||||
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||
|
||||
if relu:
|
||||
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
||||
|
||||
if dropout_rate is not None:
|
||||
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
output, output_state = fw_cell(inputs=x,
|
||||
dtype=tf.float32,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
||||
|
||||
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||
# the object it creates the variables, and then you just call it several times
|
||||
# to enable variable re-use. Because all of our code is structure in an old
|
||||
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||
# emulating a static function variable.
|
||||
if not rnn_impl_cudnn_rnn.cell:
|
||||
# Forward direction cell:
|
||||
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
||||
num_units=Config.n_cell_dim,
|
||||
input_mode='linear_input',
|
||||
direction='unidirectional',
|
||||
dtype=tf.float32)
|
||||
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||
|
||||
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
||||
sequence_lengths=seq_length)
|
||||
|
||||
return output, output_state
|
||||
|
||||
rnn_impl_cudnn_rnn.cell = None
|
||||
|
||||
|
||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
||||
# Forward direction cell:
|
||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
# Split rank N tensor into list of rank N-1 tensors
|
||||
x = [x[l] for l in range(x.shape[0])]
|
||||
|
||||
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
|
||||
inputs=x,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state,
|
||||
dtype=tf.float32,
|
||||
scope='cell_0')
|
||||
|
||||
output = tf.concat(output, 0)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
|
||||
layers = {}
|
||||
|
||||
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
if not batch_size:
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
|
||||
# Create overlapping feature windows if needed
|
||||
if overlap:
|
||||
batch_x = create_overlapping_windows(batch_x)
|
||||
|
||||
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
||||
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
||||
|
||||
# Permute n_steps and batch_size
|
||||
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
||||
# Reshape to prepare input for first layer
|
||||
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
||||
layers['input_reshaped'] = batch_x
|
||||
|
||||
# The next three blocks will pass `batch_x` through three hidden layers with
|
||||
# clipped RELU activation and dropout.
|
||||
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
|
||||
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
|
||||
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
|
||||
|
||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
||||
|
||||
# Run through parametrized RNN implementation, as we use different RNNs
|
||||
# for training and inference
|
||||
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
||||
|
||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
||||
layers['rnn_output'] = output
|
||||
layers['rnn_output_state'] = output_state
|
||||
|
||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
||||
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
|
||||
|
||||
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
||||
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
|
||||
|
||||
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
||||
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
||||
# Note, that this differs from the input in that it is time-major.
|
||||
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
|
||||
layers['raw_logits'] = layer_6
|
||||
|
||||
# Output shape: [n_steps, batch_size, n_hidden_6]
|
||||
return layer_6, layers
|
||||
|
||||
|
||||
# Accuracy and Loss
|
||||
# =================
|
||||
|
||||
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# the loss function used by our network should be the CTC loss function
|
||||
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
|
||||
# Conveniently, this loss function is implemented in TensorFlow.
|
||||
# Thus, we can simply make use of this implementation to define our loss.
|
||||
|
||||
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
r'''
|
||||
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
|
||||
Next to total and average loss it returns the mean edit distance,
|
||||
the decoded result and the batch's original Y.
|
||||
'''
|
||||
# Obtain the next batch of data
|
||||
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
|
||||
if FLAGS.train_cudnn:
|
||||
rnn_impl = rnn_impl_cudnn_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
# Calculate the logits of the batch
|
||||
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
|
||||
|
||||
# Compute the CTC loss using TensorFlow's `ctc_loss`
|
||||
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
|
||||
|
||||
# Check if any files lead to non finite loss
|
||||
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
|
||||
|
||||
# Calculate the average loss across the batch
|
||||
avg_loss = tf.reduce_mean(input_tensor=total_loss)
|
||||
|
||||
# Finally we return the average loss
|
||||
return avg_loss, non_finite_files
|
||||
|
||||
|
||||
# Adam Optimization
|
||||
# =================
|
||||
|
||||
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# in which 'Nesterov's Accelerated Gradient Descent'
|
||||
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
|
||||
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
|
||||
# because, generally, it requires less fine-tuning.
|
||||
def create_optimizer(learning_rate_var):
|
||||
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||
beta1=FLAGS.beta1,
|
||||
beta2=FLAGS.beta2,
|
||||
epsilon=FLAGS.epsilon)
|
||||
return optimizer
|
||||
|
||||
|
||||
# Towers
|
||||
# ======
|
||||
|
||||
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
|
||||
# not present when using a single GPU, that facilitate the multi-GPU use case.
|
||||
# In particular, one must introduce a means to isolate the inference and gradient
|
||||
# calculations on the various GPU's.
|
||||
# The abstraction we intoduce for this purpose is called a 'tower'.
|
||||
# A tower is specified by two properties:
|
||||
# * **Scope** - A scope, as provided by `tf.name_scope()`,
|
||||
# is a means to isolate the operations within a tower.
|
||||
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
|
||||
# * **Device** - A hardware device, as provided by `tf.device()`,
|
||||
# on which all operations within the tower execute.
|
||||
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
|
||||
|
||||
def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
r'''
|
||||
With this preliminary step out of the way, we can for each GPU introduce a
|
||||
tower for which's batch we calculate and return the optimization gradients
|
||||
and the average loss across towers.
|
||||
'''
|
||||
# To calculate the mean of the losses
|
||||
tower_avg_losses = []
|
||||
|
||||
# Tower gradients to return
|
||||
tower_gradients = []
|
||||
|
||||
# Aggregate any non finite files in the batches
|
||||
tower_non_finite_files = []
|
||||
|
||||
with tfv1.variable_scope(tfv1.get_variable_scope()):
|
||||
# Loop over available_devices
|
||||
for i in range(len(Config.available_devices)):
|
||||
# Execute operations of tower i on device i
|
||||
device = Config.available_devices[i]
|
||||
with tf.device(device):
|
||||
# Create a scope for all operations of tower i
|
||||
with tf.name_scope('tower_%d' % i):
|
||||
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
|
||||
# batch along with the original batch's labels (Y) of this tower
|
||||
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
|
||||
# Allow for variables to be re-used by the next tower
|
||||
tfv1.get_variable_scope().reuse_variables()
|
||||
|
||||
# Retain tower's avg losses
|
||||
tower_avg_losses.append(avg_loss)
|
||||
|
||||
# Compute gradients for model parameters using tower's mini-batch
|
||||
gradients = optimizer.compute_gradients(avg_loss)
|
||||
|
||||
# Retain tower's gradients
|
||||
tower_gradients.append(gradients)
|
||||
|
||||
tower_non_finite_files.append(non_finite_files)
|
||||
|
||||
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
|
||||
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
|
||||
|
||||
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
|
||||
|
||||
# Return gradients and the average loss
|
||||
return tower_gradients, avg_loss_across_towers, all_non_finite_files
|
||||
|
||||
|
||||
def average_gradients(tower_gradients):
|
||||
r'''
|
||||
A routine for computing each variable's average of the gradients obtained from the GPUs.
|
||||
Note also that this code acts as a synchronization point as it requires all
|
||||
GPUs to be finished with their mini-batch before it can run to completion.
|
||||
'''
|
||||
# List of average gradients to return to the caller
|
||||
average_grads = []
|
||||
|
||||
# Run this on cpu_device to conserve GPU memory
|
||||
with tf.device(Config.cpu_device):
|
||||
# Loop over gradient/variable pairs from all towers
|
||||
for grad_and_vars in zip(*tower_gradients):
|
||||
# Introduce grads to store the gradients for the current variable
|
||||
grads = []
|
||||
|
||||
# Loop over the gradients for the current variable
|
||||
for g, _ in grad_and_vars:
|
||||
# Add 0 dimension to the gradients to represent the tower.
|
||||
expanded_g = tf.expand_dims(g, 0)
|
||||
# Append on a 'tower' dimension which we will average over below.
|
||||
grads.append(expanded_g)
|
||||
|
||||
# Average over the 'tower' dimension
|
||||
grad = tf.concat(grads, 0)
|
||||
grad = tf.reduce_mean(input_tensor=grad, axis=0)
|
||||
|
||||
# Create a gradient/variable tuple for the current variable with its average gradient
|
||||
grad_and_var = (grad, grad_and_vars[0][1])
|
||||
|
||||
# Add the current tuple to average_grads
|
||||
average_grads.append(grad_and_var)
|
||||
|
||||
# Return result to caller
|
||||
return average_grads
|
||||
|
||||
|
||||
|
||||
# Logging
|
||||
# =======
|
||||
|
||||
def log_variable(variable, gradient=None):
|
||||
r'''
|
||||
We introduce a function for logging a tensor variable's current state.
|
||||
It logs scalar values for the mean, standard deviation, minimum and maximum.
|
||||
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
|
||||
'''
|
||||
name = variable.name.replace(':', '_')
|
||||
mean = tf.reduce_mean(input_tensor=variable)
|
||||
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
|
||||
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
|
||||
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
|
||||
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
|
||||
tfv1.summary.histogram(name=name, values=variable)
|
||||
if gradient is not None:
|
||||
if isinstance(gradient, tf.IndexedSlices):
|
||||
grad_values = gradient.values
|
||||
else:
|
||||
grad_values = gradient
|
||||
if grad_values is not None:
|
||||
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
|
||||
|
||||
|
||||
def log_grads_and_vars(grads_and_vars):
|
||||
r'''
|
||||
Let's also introduce a helper function for logging collections of gradient/variable tuples.
|
||||
'''
|
||||
for gradient, variable in grads_and_vars:
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def train():
|
||||
do_cache_dataset = True
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||
FLAGS.data_aug_features_additive > 0 or
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0 or
|
||||
FLAGS.augmentation_sparse_warp):
|
||||
do_cache_dataset = False
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
enable_cache=FLAGS.feature_cache and do_cache_dataset,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
output_classes=tfv1.data.get_output_classes(train_set))
|
||||
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
dev_sources = FLAGS.dev_files.split(',')
|
||||
dev_sets = [create_dataset([source],
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
# Dropout
|
||||
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
dropout_feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
no_dropout_feed_dict = {
|
||||
rate: 0. for rate in dropout_rates
|
||||
}
|
||||
|
||||
# Building the graph
|
||||
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
|
||||
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
|
||||
optimizer = create_optimizer(learning_rate_var)
|
||||
|
||||
# Enable mixed precision training
|
||||
if FLAGS.automatic_mixed_precision:
|
||||
log_info('Enabling automatic mixed precision training.')
|
||||
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||
|
||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tfv1.train.get_or_create_global_step()
|
||||
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
|
||||
|
||||
# Summaries
|
||||
step_summaries_op = tfv1.summary.merge_all('step_summaries')
|
||||
step_summary_writers = {
|
||||
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
|
||||
}
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
|
||||
|
||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
# Load checkpoint or initialize variables
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last', 'init']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_set(set_name, epoch, init_op, dataset=None):
|
||||
is_train = set_name == 'train'
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||
|
||||
total_loss = 0.0
|
||||
step_count = 0
|
||||
|
||||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
def __init__(self):
|
||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||
|
||||
def __call__(self, progress, data, **kwargs):
|
||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||
|
||||
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
|
||||
widgets = [' | ', progressbar.widgets.Timer(),
|
||||
' | Steps: ', progressbar.widgets.Counter(),
|
||||
' | ', LossWidget()]
|
||||
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# Batch loop
|
||||
while True:
|
||||
try:
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.InvalidArgumentError as err:
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
log_info("Ignoring sparse warp error: {}".format(err))
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
|
||||
log_error('The following files caused an infinite (or NaN) '
|
||||
'loss: {}'.format(','.join(problem_files)))
|
||||
|
||||
total_loss += batch_loss
|
||||
step_count += 1
|
||||
|
||||
pbar.update(step_count)
|
||||
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
|
||||
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
pbar.finish()
|
||||
mean_loss = total_loss / step_count if step_count > 0 else 0.0
|
||||
return mean_loss, step_count
|
||||
|
||||
log_info('STARTING Optimization')
|
||||
train_start_time = datetime.utcnow()
|
||||
best_dev_loss = float('inf')
|
||||
dev_losses = []
|
||||
epochs_without_improvement = 0
|
||||
try:
|
||||
for epoch in range(FLAGS.epochs):
|
||||
# Training
|
||||
log_progress('Training epoch %d...' % epoch)
|
||||
train_loss, _ = run_set('train', epoch, train_init_op)
|
||||
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, source))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
total_steps += steps
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
dev_loss = dev_loss / total_steps
|
||||
dev_losses.append(dev_loss)
|
||||
|
||||
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
|
||||
# the improvement has to be greater than FLAGS.es_min_delta
|
||||
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
|
||||
epochs_without_improvement += 1
|
||||
else:
|
||||
epochs_without_improvement = 0
|
||||
|
||||
# Save new best model
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
|
||||
# Early stopping
|
||||
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
|
||||
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
|
||||
epochs_without_improvement))
|
||||
break
|
||||
|
||||
# Reduce learning rate on plateau
|
||||
if (FLAGS.reduce_lr_on_plateau and
|
||||
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
session.run(reduce_learning_rate_op)
|
||||
current_learning_rate = learning_rate_var.eval()
|
||||
log_info('Encountered a plateau, reducing learning rate to {}'.format(
|
||||
current_learning_rate))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||
log_debug('Session closed.')
|
||||
|
||||
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
|
||||
mfccs = tf.identity(mfccs, name='mfccs')
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in DS_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
|
||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
|
||||
if tflite:
|
||||
rnn_impl = rnn_impl_static_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
logits, layers = create_model(batch_x=input_tensor,
|
||||
batch_size=batch_size,
|
||||
seq_length=seq_length if not FLAGS.export_tflite else None,
|
||||
dropout=no_dropout,
|
||||
previous_state=previous_state,
|
||||
overlap=False,
|
||||
rnn_impl=rnn_impl)
|
||||
|
||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits, name='logits')
|
||||
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
||||
if n_steps > 0:
|
||||
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': logits,
|
||||
},
|
||||
layers
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||
|
||||
inputs = {
|
||||
'input': input_tensor,
|
||||
'previous_state_c': previous_state_c,
|
||||
'previous_state_h': previous_state_h,
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
inputs['input_lengths'] = seq_length
|
||||
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
|
||||
def export():
|
||||
r'''
|
||||
Restores the trained variables into a simpler graph that will be exported for serving.
|
||||
'''
|
||||
log_info('Exporting the model...')
|
||||
from tensorflow.python.framework.ops import Tensor, Operation
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
|
||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
assert graph_version > 0
|
||||
|
||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||
|
||||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
output_names = output_names_tensors + output_names_ops
|
||||
|
||||
with tf.Session() as session:
|
||||
# Restore variables from checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
sess=session,
|
||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||
output_node_names=output_names)
|
||||
|
||||
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||
graph_def=frozen_graph,
|
||||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
|
||||
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
|
||||
FLAGS.export_author_id,
|
||||
FLAGS.export_model_name,
|
||||
FLAGS.export_model_version))
|
||||
|
||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||
with open(metadata_fname, 'w') as f:
|
||||
f.write('---\n')
|
||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
|
||||
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
|
||||
f.write('license: {}\n'.format(FLAGS.export_license))
|
||||
f.write('language: {}\n'.format(FLAGS.export_language))
|
||||
f.write('runtime: {}\n'.format(model_runtime))
|
||||
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
|
||||
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
|
||||
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
|
||||
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
|
||||
f.write('---\n')
|
||||
f.write('{}\n'.format(FLAGS.export_description))
|
||||
|
||||
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
||||
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
log_info('Exported packaged model {}'.format(archive))
|
||||
|
||||
|
||||
def do_single_file_inference(input_file_path):
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
||||
|
||||
# Add batch dimension
|
||||
features = tf.expand_dims(features, 0)
|
||||
features_len = tf.expand_dims(features_len, 0)
|
||||
|
||||
# Evaluate
|
||||
features = create_overlapping_windows(features).eval(session=session)
|
||||
features_len = features_len.eval(session=session)
|
||||
|
||||
logits = outputs['outputs'].eval(feed_dict={
|
||||
inputs['input']: features,
|
||||
inputs['input_lengths']: features_len,
|
||||
inputs['previous_state_c']: previous_state_c,
|
||||
inputs['previous_state_h']: previous_state_h,
|
||||
}, session=session)
|
||||
|
||||
logits = np.squeeze(logits)
|
||||
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
||||
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
# Print highest probability result
|
||||
print(decoded[0][1])
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if FLAGS.train_files:
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test_files:
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
export()
|
||||
package_zip()
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
try:
|
||||
from deepspeech_training import train as ds_train
|
||||
except ImportError:
|
||||
print('Training package is not installed. See training documentation.')
|
||||
raise
|
||||
|
||||
ds_train.run_script()
|
||||
|
@ -150,7 +150,7 @@ COPY . /DeepSpeech/
|
||||
|
||||
WORKDIR /DeepSpeech
|
||||
|
||||
RUN pip3 --no-cache-dir install -r requirements.txt
|
||||
RUN pip3 --no-cache-dir install .
|
||||
|
||||
# Link DeepSpeech native_client libs to tf folder
|
||||
RUN ln -s /DeepSpeech/native_client /tensorflow
|
||||
|
@ -1,53 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
|
||||
Use "python3 build_sdb.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import progressbar
|
||||
|
||||
from util.downloader import SIMPLE_BAR
|
||||
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
|
||||
from util.sample_collections import samples_from_files, DirectSDBWriter
|
||||
from deepspeech_training.util.audio import (
|
||||
AUDIO_TYPE_OPUS,
|
||||
AUDIO_TYPE_WAV,
|
||||
change_audio_types,
|
||||
)
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.sample_collections import (
|
||||
DirectSDBWriter,
|
||||
samples_from_files,
|
||||
)
|
||||
|
||||
AUDIO_TYPE_LOOKUP = {
|
||||
'wav': AUDIO_TYPE_WAV,
|
||||
'opus': AUDIO_TYPE_OPUS
|
||||
}
|
||||
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
|
||||
|
||||
|
||||
def build_sdb():
|
||||
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
|
||||
with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
|
||||
with DirectSDBWriter(
|
||||
CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
|
||||
) as sdb_writer:
|
||||
samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
|
||||
bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
|
||||
for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
|
||||
for sample in bar(
|
||||
change_audio_types(
|
||||
samples, audio_type=audio_type, processes=CLI_ARGS.workers
|
||||
)
|
||||
):
|
||||
sdb_writer.add(sample)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
|
||||
'from DeepSpeech CSV files and other SDB files')
|
||||
parser.add_argument('sources', nargs='+',
|
||||
help='Source CSV and/or SDB files - '
|
||||
'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
|
||||
'already ordered from shortest to longest.')
|
||||
parser.add_argument('target', help='SDB file to create')
|
||||
parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
|
||||
help='Audio representation inside target SDB')
|
||||
parser.add_argument('--workers', type=int, default=None,
|
||||
help='Number of encoding SDB workers')
|
||||
parser.add_argument('--unlabeled', action='store_true',
|
||||
help='If to build an SDB with unlabeled (audio only) samples - '
|
||||
'typically used for building noise augmentation corpora')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tool for building Sample Databases (SDB files) "
|
||||
"from DeepSpeech CSV files and other SDB files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"sources",
|
||||
nargs="+",
|
||||
help="Source CSV and/or SDB files - "
|
||||
"Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
|
||||
"already ordered from shortest to longest.",
|
||||
)
|
||||
parser.add_argument("target", help="SDB file to create")
|
||||
parser.add_argument(
|
||||
"--audio-type",
|
||||
default="opus",
|
||||
choices=AUDIO_TYPE_LOOKUP.keys(),
|
||||
help="Audio representation inside target SDB",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=None, help="Number of encoding SDB workers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unlabeled",
|
||||
action="store_true",
|
||||
help="If to build an SDB with unlabeled (audio only) samples - "
|
||||
"typically used for building noise augmentation corpora",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
|
||||
import os
|
||||
sys.path.append(os.path.abspath('.'))
|
||||
|
||||
from util.gpu_usage import GPUUsage
|
||||
|
||||
gu = GPUUsage()
|
||||
gu.start()
|
@ -1,10 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
|
||||
import os
|
||||
sys.path.append(os.path.abspath('.'))
|
||||
|
||||
from util.gpu_usage import GPUUsageChart
|
||||
|
||||
GPUUsageChart(sys.argv[1], sys.argv[2])
|
@ -1,20 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import sys
|
||||
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from google.protobuf import text_format
|
||||
|
||||
|
||||
def main():
|
||||
# Load and export as string
|
||||
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
|
||||
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
|
||||
graph_def = tfv1.GraphDef()
|
||||
graph_def.ParseFromString(fin.read())
|
||||
|
||||
with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
|
||||
with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
|
||||
fout.write(text_format.MessageToString(graph_def))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -1,23 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import pandas
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
import pandas
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
@ -25,9 +19,9 @@ def extract(archive_path, target_dir):
|
||||
def preprocess_data(tgz_file, target_dir):
|
||||
# First extract main archive and sub-archives
|
||||
extract(tgz_file, target_dir)
|
||||
main_folder = os.path.join(target_dir, 'aidatatang_200zh')
|
||||
main_folder = os.path.join(target_dir, "aidatatang_200zh")
|
||||
|
||||
for targz in glob.glob(os.path.join(main_folder, 'corpus', '*', '*.tar.gz')):
|
||||
for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
|
||||
extract(targz, os.path.dirname(targz))
|
||||
|
||||
# Folder structure is now:
|
||||
@ -46,9 +40,11 @@ def preprocess_data(tgz_file, target_dir):
|
||||
|
||||
# Since the transcripts themselves can contain spaces, we split on space but
|
||||
# only once, then build a mapping from file name to transcript
|
||||
transcripts_path = os.path.join(main_folder, 'transcript', 'aidatatang_200_zh_transcript.txt')
|
||||
transcripts_path = os.path.join(
|
||||
main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
|
||||
)
|
||||
with open(transcripts_path) as fin:
|
||||
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
|
||||
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
|
||||
|
||||
def load_set(glob_path):
|
||||
set_files = []
|
||||
@ -57,33 +53,39 @@ def preprocess_data(tgz_file, target_dir):
|
||||
wav_filename = wav
|
||||
wav_filesize = os.path.getsize(wav)
|
||||
transcript_key = os.path.splitext(os.path.basename(wav))[0]
|
||||
transcript = transcripts[transcript_key].strip('\n')
|
||||
transcript = transcripts[transcript_key].strip("\n")
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
except KeyError:
|
||||
print('Warning: Missing transcript for WAV file {}.'.format(wav))
|
||||
print("Warning: Missing transcript for WAV file {}.".format(wav))
|
||||
return set_files
|
||||
|
||||
for subset in ('train', 'dev', 'test'):
|
||||
print('Loading {} set samples...'.format(subset))
|
||||
subset_files = load_set(os.path.join(main_folder, 'corpus', subset, '*', '*.wav'))
|
||||
for subset in ("train", "dev", "test"):
|
||||
print("Loading {} set samples...".format(subset))
|
||||
subset_files = load_set(
|
||||
os.path.join(main_folder, "corpus", subset, "*", "*.wav")
|
||||
)
|
||||
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
|
||||
|
||||
# Trim train set to under 10s by removing the last couple hundred samples
|
||||
if subset == 'train':
|
||||
durations = (df['wav_filesize'] - 44) / 16000 / 2
|
||||
if subset == "train":
|
||||
durations = (df["wav_filesize"] - 44) / 16000 / 2
|
||||
df = df[durations <= 10.0]
|
||||
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
|
||||
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
|
||||
|
||||
dest_csv = os.path.join(target_dir, 'aidatatang_{}.csv'.format(subset))
|
||||
print('Saving {} set into {}...'.format(subset, dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
|
||||
print("Saving {} set into {}...".format(subset, dest_csv))
|
||||
df.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# https://www.openslr.org/62/
|
||||
parser = get_importers_parser(description='Import aidatatang_200zh corpus')
|
||||
parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
|
||||
parser = get_importers_parser(description="Import aidatatang_200zh corpus")
|
||||
parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
|
@ -1,23 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
import pandas
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
@ -25,10 +19,10 @@ def extract(archive_path, target_dir):
|
||||
def preprocess_data(tgz_file, target_dir):
|
||||
# First extract main archive and sub-archives
|
||||
extract(tgz_file, target_dir)
|
||||
main_folder = os.path.join(target_dir, 'data_aishell')
|
||||
main_folder = os.path.join(target_dir, "data_aishell")
|
||||
|
||||
wav_archives_folder = os.path.join(main_folder, 'wav')
|
||||
for targz in glob.glob(os.path.join(wav_archives_folder, '*.tar.gz')):
|
||||
wav_archives_folder = os.path.join(main_folder, "wav")
|
||||
for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
|
||||
extract(targz, main_folder)
|
||||
|
||||
# Folder structure is now:
|
||||
@ -45,9 +39,11 @@ def preprocess_data(tgz_file, target_dir):
|
||||
|
||||
# Since the transcripts themselves can contain spaces, we split on space but
|
||||
# only once, then build a mapping from file name to transcript
|
||||
transcripts_path = os.path.join(main_folder, 'transcript', 'aishell_transcript_v0.8.txt')
|
||||
transcripts_path = os.path.join(
|
||||
main_folder, "transcript", "aishell_transcript_v0.8.txt"
|
||||
)
|
||||
with open(transcripts_path) as fin:
|
||||
transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
|
||||
transcripts = dict((line.split(" ", maxsplit=1) for line in fin))
|
||||
|
||||
def load_set(glob_path):
|
||||
set_files = []
|
||||
@ -56,33 +52,37 @@ def preprocess_data(tgz_file, target_dir):
|
||||
wav_filename = wav
|
||||
wav_filesize = os.path.getsize(wav)
|
||||
transcript_key = os.path.splitext(os.path.basename(wav))[0]
|
||||
transcript = transcripts[transcript_key].strip('\n')
|
||||
transcript = transcripts[transcript_key].strip("\n")
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
except KeyError:
|
||||
print('Warning: Missing transcript for WAV file {}.'.format(wav))
|
||||
print("Warning: Missing transcript for WAV file {}.".format(wav))
|
||||
return set_files
|
||||
|
||||
for subset in ('train', 'dev', 'test'):
|
||||
print('Loading {} set samples...'.format(subset))
|
||||
subset_files = load_set(os.path.join(main_folder, subset, 'S*', '*.wav'))
|
||||
for subset in ("train", "dev", "test"):
|
||||
print("Loading {} set samples...".format(subset))
|
||||
subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
|
||||
df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)
|
||||
|
||||
# Trim train set to under 10s by removing the last couple hundred samples
|
||||
if subset == 'train':
|
||||
durations = (df['wav_filesize'] - 44) / 16000 / 2
|
||||
if subset == "train":
|
||||
durations = (df["wav_filesize"] - 44) / 16000 / 2
|
||||
df = df[durations <= 10.0]
|
||||
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
|
||||
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
|
||||
|
||||
dest_csv = os.path.join(target_dir, 'aishell_{}.csv'.format(subset))
|
||||
print('Saving {} set into {}...'.format(subset, dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
|
||||
print("Saving {} set into {}...".format(subset, dest_csv))
|
||||
df.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# http://www.openslr.org/33/
|
||||
parser = get_importers_parser(description='Import AISHELL corpus')
|
||||
parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
|
||||
parser = get_importers_parser(description="Import AISHELL corpus")
|
||||
parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
|
109
bin/import_cv.py
109
bin/import_cv.py
@ -1,34 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import sox
|
||||
import tarfile
|
||||
import os
|
||||
import subprocess
|
||||
import progressbar
|
||||
|
||||
import tarfile
|
||||
from glob import glob
|
||||
from os import path
|
||||
from multiprocessing import Pool
|
||||
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
ARCHIVE_DIR_NAME = 'cv_corpus_v1'
|
||||
ARCHIVE_NAME = ARCHIVE_DIR_NAME + '.tar.gz'
|
||||
ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "cv_corpus_v1"
|
||||
ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
|
||||
ARCHIVE_URL = (
|
||||
"https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
|
||||
)
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract common voice data
|
||||
@ -36,56 +37,70 @@ def _download_and_preprocess_data(target_dir):
|
||||
# Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
else:
|
||||
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
for source_csv in glob(path.join(extracted_dir, '*.csv')):
|
||||
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
|
||||
_maybe_convert_set(
|
||||
extracted_dir,
|
||||
source_csv,
|
||||
os.path.join(target_dir, os.path.split(source_csv)[-1]),
|
||||
)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
mp3_filename = sample[0]
|
||||
# Storing wav files next to the mp3 ones - just with a different suffix
|
||||
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
|
||||
_maybe_convert_wav(mp3_filename, wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
frames = int(
|
||||
subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
|
||||
)
|
||||
file_size = -1
|
||||
if path.exists(wav_filename):
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = validate_label(sample[1])
|
||||
rows = []
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||
print()
|
||||
if path.exists(target_csv):
|
||||
if os.path.exists(target_csv):
|
||||
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
|
||||
return
|
||||
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
|
||||
@ -93,14 +108,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||
with open(source_csv) as source_csv_file:
|
||||
reader = csv.DictReader(source_csv_file)
|
||||
for row in reader:
|
||||
samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
|
||||
samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))
|
||||
|
||||
# Mutable counters for the concurrent embedded routine
|
||||
counter = get_counter()
|
||||
num_samples = len(samples)
|
||||
rows = []
|
||||
|
||||
print('Importing mp3 files...')
|
||||
print("Importing mp3 files...")
|
||||
pool = Pool()
|
||||
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
|
||||
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
|
||||
@ -112,21 +127,28 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||
pool.join()
|
||||
|
||||
print('Writing "%s"...' % target_csv)
|
||||
with open(target_csv, 'w') as target_csv_file:
|
||||
with open(target_csv, "w") as target_csv_file:
|
||||
writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
|
||||
writer.writeheader()
|
||||
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
|
||||
for filename, file_size, transcript in bar(rows):
|
||||
writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
|
||||
writer.writerow(
|
||||
{
|
||||
"wav_filename": filename,
|
||||
"wav_filesize": file_size,
|
||||
"transcript": transcript,
|
||||
}
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
@ -134,5 +156,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
except sox.core.SoxError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,90 +1,96 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Broadly speaking, this script takes the audio downloaded from Common Voice
|
||||
for a certain language, in addition to the *.tsv files output by CorporaCreator,
|
||||
and the script formats the data and transcripts to be in a state usable by
|
||||
DeepSpeech.py
|
||||
Use "python3 import_cv2.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
"""
|
||||
import csv
|
||||
import sox
|
||||
import os
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
|
||||
from os import path
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
from util.text import Alphabet
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
|
||||
|
||||
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
|
||||
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
|
||||
for dataset in ["train", "test", "dev", "validated", "other"]:
|
||||
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
|
||||
if os.path.isfile(input_tsv):
|
||||
print("Loading TSV file: ", input_tsv)
|
||||
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
mp3_filename = sample[0]
|
||||
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
|
||||
if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
|
||||
mp3_filename += ".mp3"
|
||||
# Storing wav files next to the mp3 ones - just with a different suffix
|
||||
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
|
||||
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
|
||||
_maybe_convert_wav(mp3_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter_fun(sample[1])
|
||||
rows = []
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((os.path.split(wav_filename)[-1], file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
|
||||
output_csv = os.path.join(
|
||||
audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
|
||||
)
|
||||
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
with open(input_tsv, encoding='utf-8') as input_tsv_file:
|
||||
reader = csv.DictReader(input_tsv_file, delimiter='\t')
|
||||
with open(input_tsv, encoding="utf-8") as input_tsv_file:
|
||||
reader = csv.DictReader(input_tsv_file, delimiter="\t")
|
||||
for row in reader:
|
||||
samples.append((path.join(audio_dir, row['path']), row['sentence']))
|
||||
samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))
|
||||
|
||||
counter = get_counter()
|
||||
num_samples = len(samples)
|
||||
@ -101,26 +107,38 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(output_csv, 'w', encoding='utf-8') as output_csv_file:
|
||||
print('Writing CSV file for DeepSpeech.py as: ', output_csv)
|
||||
with open(output_csv, "w", encoding="utf-8") as output_csv_file:
|
||||
print("Writing CSV file for DeepSpeech.py as: ", output_csv)
|
||||
writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
|
||||
writer.writeheader()
|
||||
bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
|
||||
for filename, file_size, transcript in bar(rows):
|
||||
if space_after_every_character:
|
||||
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': ' '.join(transcript)})
|
||||
writer.writerow(
|
||||
{
|
||||
"wav_filename": filename,
|
||||
"wav_filesize": file_size,
|
||||
"transcript": " ".join(transcript),
|
||||
}
|
||||
)
|
||||
else:
|
||||
writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
|
||||
writer.writerow(
|
||||
{
|
||||
"wav_filename": filename,
|
||||
"wav_filesize": file_size,
|
||||
"transcript": transcript,
|
||||
}
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
@ -130,24 +148,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
|
||||
PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
|
||||
PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
|
||||
PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
PARSER.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
|
||||
PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
|
||||
PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
|
||||
PARSER.add_argument(
|
||||
"--audio_dir",
|
||||
help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
|
||||
)
|
||||
PARSER.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
PARSER.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
PARSER.add_argument(
|
||||
"--space_after_every_character",
|
||||
action="store_true",
|
||||
help="To help transcript join by white space",
|
||||
)
|
||||
|
||||
PARAMS = PARSER.parse_args()
|
||||
validate_label = get_validate_label(PARAMS)
|
||||
|
||||
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
|
||||
AUDIO_DIR = (
|
||||
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
|
||||
)
|
||||
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
|
||||
|
||||
def label_filter_fun(label):
|
||||
if PARAMS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
|
@ -1,25 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import codecs
|
||||
import fnmatch
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import unicodedata
|
||||
|
||||
import librosa
|
||||
import pandas
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
# Prerequisite: Having the sph2pipe tool in your PATH:
|
||||
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import os
|
||||
import pandas
|
||||
import subprocess
|
||||
import unicodedata
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
||||
@ -29,33 +24,55 @@ def _download_and_preprocess_data(data_dir):
|
||||
_maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")
|
||||
|
||||
# Conditionally split Fisher wav data
|
||||
all_2004 = _split_wav_and_sentences(data_dir,
|
||||
original_data="fisher-2004-wav",
|
||||
converted_data="fisher-2004-split-wav",
|
||||
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"))
|
||||
all_2005 = _split_wav_and_sentences(data_dir,
|
||||
original_data="fisher-2005-wav",
|
||||
converted_data="fisher-2005-split-wav",
|
||||
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"))
|
||||
all_2004 = _split_wav_and_sentences(
|
||||
data_dir,
|
||||
original_data="fisher-2004-wav",
|
||||
converted_data="fisher-2004-split-wav",
|
||||
trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
|
||||
)
|
||||
all_2005 = _split_wav_and_sentences(
|
||||
data_dir,
|
||||
original_data="fisher-2005-wav",
|
||||
converted_data="fisher-2005-split-wav",
|
||||
trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
|
||||
)
|
||||
|
||||
# The following files have incorrect transcripts that are much longer than
|
||||
# their audio source. The result is that we end up with more labels than time
|
||||
# slices, which breaks CTC.
|
||||
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"), "transcript"] = "correct"
|
||||
all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"), "transcript"] = "that's one of those"
|
||||
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"), "transcript"] = "they don't want"
|
||||
all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"), "transcript"] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
|
||||
all_2004.loc[
|
||||
all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
|
||||
"transcript",
|
||||
] = "correct"
|
||||
all_2004.loc[
|
||||
all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
|
||||
"transcript",
|
||||
] = "that's one of those"
|
||||
all_2005.loc[
|
||||
all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
|
||||
"transcript",
|
||||
] = "they don't want"
|
||||
all_2005.loc[
|
||||
all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
|
||||
"transcript",
|
||||
] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
|
||||
|
||||
# The following file is just a short sound and not at all transcribed like provided.
|
||||
# So we just exclude it.
|
||||
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")]
|
||||
all_2004 = all_2004[
|
||||
~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
|
||||
]
|
||||
|
||||
# The following file is far too long and would ruin our training batch size.
|
||||
# So we just exclude it.
|
||||
all_2005 = all_2005[~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")]
|
||||
all_2005 = all_2005[
|
||||
~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
|
||||
]
|
||||
|
||||
# The following file is too large for its transcript, so we just exclude it.
|
||||
all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")]
|
||||
all_2004 = all_2004[
|
||||
~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
|
||||
]
|
||||
|
||||
# Conditionally split Fisher data into train/validation/test sets
|
||||
train_2004, dev_2004, test_2004 = _split_sets(all_2004)
|
||||
@ -71,6 +88,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
|
||||
test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)
|
||||
|
||||
|
||||
def _maybe_convert_wav(data_dir, original_data, converted_data):
|
||||
source_dir = os.path.join(data_dir, original_data)
|
||||
target_dir = os.path.join(data_dir, converted_data)
|
||||
@ -88,10 +106,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
|
||||
for filename in fnmatch.filter(filenames, "*.sph"):
|
||||
sph_file = os.path.join(root, filename)
|
||||
for channel in ["1", "2"]:
|
||||
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "_c" + channel + ".wav"
|
||||
wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "_c"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(target_dir, wav_filename)
|
||||
print("converting {} to {}".format(sph_file, wav_file))
|
||||
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file])
|
||||
subprocess.check_call(
|
||||
["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
|
||||
)
|
||||
|
||||
|
||||
def _parse_transcriptions(trans_file):
|
||||
segments = []
|
||||
@ -109,18 +135,23 @@ def _parse_transcriptions(trans_file):
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
segments.append({
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"speaker": speaker,
|
||||
"transcript": transcript,
|
||||
})
|
||||
segments.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"speaker": speaker,
|
||||
"transcript": transcript,
|
||||
}
|
||||
)
|
||||
return segments
|
||||
|
||||
|
||||
def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
|
||||
trans_dir = os.path.join(data_dir, trans_data)
|
||||
source_dir = os.path.join(data_dir, original_data)
|
||||
@ -137,43 +168,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
|
||||
segments = _parse_transcriptions(trans_file)
|
||||
|
||||
# Open wav corresponding to transcription file
|
||||
wav_filenames = [os.path.splitext(os.path.basename(trans_file))[0] + "_c" + channel + ".wav" for channel in ["1", "2"]]
|
||||
wav_files = [os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames]
|
||||
wav_filenames = [
|
||||
os.path.splitext(os.path.basename(trans_file))[0]
|
||||
+ "_c"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
for channel in ["1", "2"]
|
||||
]
|
||||
wav_files = [
|
||||
os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
|
||||
]
|
||||
|
||||
print("splitting {} according to {}".format(wav_files, trans_file))
|
||||
|
||||
origAudios = [librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files]
|
||||
origAudios = [
|
||||
librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
|
||||
]
|
||||
|
||||
# Loop over segments and split wav_file for each segment
|
||||
for segment in segments:
|
||||
# Create wav segment filename
|
||||
start_time = segment["start_time"]
|
||||
stop_time = segment["stop_time"]
|
||||
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
os.path.splitext(os.path.basename(trans_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
new_wav_file = os.path.join(target_dir, new_wav_filename)
|
||||
|
||||
channel = 0 if segment["speaker"] == "A:" else 1
|
||||
_split_and_resample_wav(origAudios[channel], start_time, stop_time, new_wav_file)
|
||||
_split_and_resample_wav(
|
||||
origAudios[channel], start_time, stop_time, new_wav_file
|
||||
)
|
||||
|
||||
new_wav_filesize = os.path.getsize(new_wav_file)
|
||||
transcript = validate_label(segment["transcript"])
|
||||
if transcript != None:
|
||||
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
|
||||
files.append(
|
||||
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
|
||||
)
|
||||
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
|
||||
def _split_audio(origAudio, start_time, stop_time):
|
||||
audioData, frameRate = origAudio
|
||||
nChannels = len(audioData.shape)
|
||||
startIndex = int(start_time * frameRate)
|
||||
stopIndex = int(stop_time * frameRate)
|
||||
return audioData[startIndex: stopIndex] if 1 == nChannels else audioData[:, startIndex: stopIndex]
|
||||
return (
|
||||
audioData[startIndex:stopIndex]
|
||||
if 1 == nChannels
|
||||
else audioData[:, startIndex:stopIndex]
|
||||
)
|
||||
|
||||
|
||||
def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
frameRate = origAudio[1]
|
||||
chunkData = _split_audio(origAudio, start_time, stop_time)
|
||||
soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")
|
||||
|
||||
|
||||
def _split_sets(filelist):
|
||||
# We initially split the entire set into 80% train and 20% test, then
|
||||
# split the train set into 80% train and 20% validation.
|
||||
@ -187,9 +248,12 @@ def _split_sets(filelist):
|
||||
test_beg = dev_end
|
||||
test_end = len(filelist)
|
||||
|
||||
return (filelist[train_beg:train_end],
|
||||
filelist[dev_beg:dev_end],
|
||||
filelist[test_beg:test_end])
|
||||
return (
|
||||
filelist[train_beg:train_end],
|
||||
filelist[dev_beg:dev_end],
|
||||
filelist[test_beg:test_end],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,24 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import numpy as np
|
||||
import pandas
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
import pandas
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
@ -26,7 +20,7 @@ def extract(archive_path, target_dir):
|
||||
def preprocess_data(tgz_file, target_dir):
|
||||
# First extract main archive and sub-archives
|
||||
extract(tgz_file, target_dir)
|
||||
main_folder = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS')
|
||||
main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")
|
||||
|
||||
# Folder structure is now:
|
||||
# - ST-CMDS-20170001_1-OS/
|
||||
@ -39,16 +33,16 @@ def preprocess_data(tgz_file, target_dir):
|
||||
for wav in glob.glob(glob_path):
|
||||
wav_filename = wav
|
||||
wav_filesize = os.path.getsize(wav)
|
||||
txt_filename = os.path.splitext(wav_filename)[0] + '.txt'
|
||||
with open(txt_filename, 'r') as fin:
|
||||
txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
|
||||
with open(txt_filename, "r") as fin:
|
||||
transcript = fin.read()
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
return set_files
|
||||
|
||||
# Load all files, then deterministically split into train/dev/test sets
|
||||
all_files = load_set(os.path.join(main_folder, '*.wav'))
|
||||
all_files = load_set(os.path.join(main_folder, "*.wav"))
|
||||
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
|
||||
df.sort_values(by='wav_filename', inplace=True)
|
||||
df.sort_values(by="wav_filename", inplace=True)
|
||||
|
||||
indices = np.arange(0, len(df))
|
||||
np.random.seed(12345)
|
||||
@ -61,29 +55,33 @@ def preprocess_data(tgz_file, target_dir):
|
||||
train_indices = indices[:-10000]
|
||||
|
||||
train_files = df.iloc[train_indices]
|
||||
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
|
||||
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
|
||||
train_files = train_files[durations <= 10.0]
|
||||
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
|
||||
dest_csv = os.path.join(target_dir, 'freestmandarin_train.csv')
|
||||
print('Saving train set into {}...'.format(dest_csv))
|
||||
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
|
||||
dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
|
||||
print("Saving train set into {}...".format(dest_csv))
|
||||
train_files.to_csv(dest_csv, index=False)
|
||||
|
||||
dev_files = df.iloc[dev_indices]
|
||||
dest_csv = os.path.join(target_dir, 'freestmandarin_dev.csv')
|
||||
print('Saving dev set into {}...'.format(dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
|
||||
print("Saving dev set into {}...".format(dest_csv))
|
||||
dev_files.to_csv(dest_csv, index=False)
|
||||
|
||||
test_files = df.iloc[test_indices]
|
||||
dest_csv = os.path.join(target_dir, 'freestmandarin_test.csv')
|
||||
print('Saving test set into {}...'.format(dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
|
||||
print("Saving test set into {}...".format(dest_csv))
|
||||
test_files.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# https://www.openslr.org/38/
|
||||
parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
|
||||
parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
|
||||
parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
|
||||
parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
|
@ -1,24 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import math
|
||||
import urllib
|
||||
import logging
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
from os import path
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
|
||||
import swifter
|
||||
import pandas as pd
|
||||
from sox import Transformer
|
||||
|
||||
import swifter
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label
|
||||
|
||||
__version__ = "0.1.0"
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -40,9 +34,7 @@ def parse_args(args):
|
||||
Returns:
|
||||
:obj:`argparse.Namespace`: command line parameters namespace
|
||||
"""
|
||||
parser = get_importers_parser(
|
||||
description="Imports GramVaani data for Deep Speech"
|
||||
)
|
||||
parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
action="version",
|
||||
@ -82,6 +74,7 @@ def parse_args(args):
|
||||
)
|
||||
return parser.parse_args(args)
|
||||
|
||||
|
||||
def setup_logging(level):
|
||||
"""Setup basic logging
|
||||
Args:
|
||||
@ -92,6 +85,7 @@ def setup_logging(level):
|
||||
level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
|
||||
class GramVaaniCSV:
|
||||
"""GramVaaniCSV representing a GramVaani dataset.
|
||||
Args:
|
||||
@ -107,8 +101,17 @@ class GramVaaniCSV:
|
||||
_logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
|
||||
data = pd.read_csv(
|
||||
os.path.abspath(csv_filename),
|
||||
names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"],
|
||||
usecols=["audio_url","transcript","audio_length"],
|
||||
names=[
|
||||
"piece_id",
|
||||
"audio_url",
|
||||
"transcript_labelled",
|
||||
"transcript",
|
||||
"labels",
|
||||
"content_filename",
|
||||
"audio_length",
|
||||
"user_id",
|
||||
],
|
||||
usecols=["audio_url", "transcript", "audio_length"],
|
||||
skiprows=[0],
|
||||
engine="python",
|
||||
encoding="utf-8",
|
||||
@ -119,6 +122,7 @@ class GramVaaniCSV:
|
||||
_logger.info("Parsed %d lines csv file." % len(data))
|
||||
return data
|
||||
|
||||
|
||||
class GramVaaniDownloader:
|
||||
"""GramVaaniDownloader downloads a GramVaani dataset.
|
||||
Args:
|
||||
@ -138,15 +142,17 @@ class GramVaaniDownloader:
|
||||
mp3_directory (os.path): The directory into which the associated mp3's were downloaded
|
||||
"""
|
||||
mp3_directory = self._pre_download()
|
||||
self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True)
|
||||
self.data.swifter.apply(
|
||||
func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
|
||||
)
|
||||
return mp3_directory
|
||||
|
||||
def _pre_download(self):
|
||||
mp3_directory = path.join(self.target_dir, "mp3")
|
||||
if not path.exists(self.target_dir):
|
||||
mp3_directory = os.path.join(self.target_dir, "mp3")
|
||||
if not os.path.exists(self.target_dir):
|
||||
_logger.info("Creating directory...%s", self.target_dir)
|
||||
os.mkdir(self.target_dir)
|
||||
if not path.exists(mp3_directory):
|
||||
if not os.path.exists(mp3_directory):
|
||||
_logger.info("Creating directory...%s", mp3_directory)
|
||||
os.mkdir(mp3_directory)
|
||||
return mp3_directory
|
||||
@ -154,13 +160,14 @@ class GramVaaniDownloader:
|
||||
def _download(self, audio_url, transcript, audio_length, mp3_directory):
|
||||
if audio_url == "audio_url":
|
||||
return
|
||||
mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
|
||||
if not path.exists(mp3_filename):
|
||||
mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
|
||||
if not os.path.exists(mp3_filename):
|
||||
_logger.debug("Downloading mp3 file...%s", audio_url)
|
||||
urllib.request.urlretrieve(audio_url, mp3_filename)
|
||||
else:
|
||||
_logger.debug("Already downloaded mp3 file...%s", audio_url)
|
||||
|
||||
|
||||
class GramVaaniConverter:
|
||||
"""GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
|
||||
Args:
|
||||
@ -181,37 +188,53 @@ class GramVaaniConverter:
|
||||
wav_directory (os.path): The directory into which the associated wav's were downloaded
|
||||
"""
|
||||
wav_directory = self._pre_convert()
|
||||
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
|
||||
wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
if not path.exists(wav_filename):
|
||||
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
|
||||
for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
|
||||
wav_filename = os.path.join(
|
||||
wav_directory,
|
||||
os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
|
||||
)
|
||||
if not os.path.exists(wav_filename):
|
||||
_logger.debug(
|
||||
"Converting mp3 file %s to wav file %s"
|
||||
% (mp3_filename, wav_filename)
|
||||
)
|
||||
transformer = Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
|
||||
transformer.convert(
|
||||
samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
|
||||
)
|
||||
transformer.build(str(mp3_filename), str(wav_filename))
|
||||
else:
|
||||
_logger.debug("Already converted mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
|
||||
_logger.debug(
|
||||
"Already converted mp3 file %s to wav file %s"
|
||||
% (mp3_filename, wav_filename)
|
||||
)
|
||||
return wav_directory
|
||||
|
||||
def _pre_convert(self):
|
||||
wav_directory = path.join(self.target_dir, "wav")
|
||||
if not path.exists(self.target_dir):
|
||||
wav_directory = os.path.join(self.target_dir, "wav")
|
||||
if not os.path.exists(self.target_dir):
|
||||
_logger.info("Creating directory...%s", self.target_dir)
|
||||
os.mkdir(self.target_dir)
|
||||
if not path.exists(wav_directory):
|
||||
if not os.path.exists(wav_directory):
|
||||
_logger.info("Creating directory...%s", wav_directory)
|
||||
os.mkdir(wav_directory)
|
||||
return wav_directory
|
||||
|
||||
|
||||
class GramVaaniDataSets:
|
||||
def __init__(self, target_dir, wav_directory, gram_vaani_csv):
|
||||
self.target_dir = target_dir
|
||||
self.wav_directory = wav_directory
|
||||
self.csv_data = gram_vaani_csv.data
|
||||
self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
|
||||
self.valid = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
|
||||
self.train = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
|
||||
self.dev = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
|
||||
self.test = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
|
||||
self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
self.valid = pd.DataFrame(
|
||||
columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
self.train = pd.DataFrame(
|
||||
columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
|
||||
def create(self):
|
||||
self._convert_csv_data_to_raw_data()
|
||||
@ -220,30 +243,45 @@ class GramVaaniDataSets:
|
||||
self.valid = self.valid.sample(frac=1).reset_index(drop=True)
|
||||
train_size, dev_size, test_size = self._calculate_data_set_sizes()
|
||||
self.train = self.valid.loc[0:train_size]
|
||||
self.dev = self.valid.loc[train_size:train_size+dev_size]
|
||||
self.test = self.valid.loc[train_size+dev_size:train_size+dev_size+test_size]
|
||||
self.dev = self.valid.loc[train_size : train_size + dev_size]
|
||||
self.test = self.valid.loc[
|
||||
train_size + dev_size : train_size + dev_size + test_size
|
||||
]
|
||||
|
||||
def _convert_csv_data_to_raw_data(self):
|
||||
self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[
|
||||
["audio_url","transcript","audio_length"]
|
||||
].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)
|
||||
self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
|
||||
["audio_url", "transcript", "audio_length"]
|
||||
].swifter.apply(
|
||||
func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
|
||||
axis=1,
|
||||
raw=True,
|
||||
)
|
||||
self.raw.reset_index()
|
||||
|
||||
def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
|
||||
if audio_url == "audio_url":
|
||||
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
|
||||
mp3_filename = os.path.basename(audio_url)
|
||||
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
|
||||
wav_relative_filename = os.path.join(
|
||||
"wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
|
||||
)
|
||||
wav_filesize = os.path.getsize(
|
||||
os.path.join(self.target_dir, wav_relative_filename)
|
||||
)
|
||||
transcript = validate_label(transcript)
|
||||
if None == transcript:
|
||||
transcript = ""
|
||||
return pd.Series([wav_relative_filename, wav_filesize, transcript])
|
||||
return pd.Series([wav_relative_filename, wav_filesize, transcript])
|
||||
|
||||
def _is_valid_raw_rows(self):
|
||||
is_valid_raw_transcripts = self._is_valid_raw_transcripts()
|
||||
is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
|
||||
is_valid_raw_row = [(is_valid_raw_transcript & is_valid_raw_wav_frame) for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(is_valid_raw_transcripts, is_valid_raw_wav_frames)]
|
||||
is_valid_raw_row = [
|
||||
(is_valid_raw_transcript & is_valid_raw_wav_frame)
|
||||
for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
|
||||
is_valid_raw_transcripts, is_valid_raw_wav_frames
|
||||
)
|
||||
]
|
||||
series = pd.Series(is_valid_raw_row)
|
||||
return series
|
||||
|
||||
@ -252,16 +290,29 @@ class GramVaaniDataSets:
|
||||
|
||||
def _is_valid_raw_wav_frames(self):
|
||||
transcripts = [str(transcript) for transcript in self.raw.transcript]
|
||||
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
||||
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
|
||||
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
|
||||
wav_filepaths = [
|
||||
os.path.join(self.target_dir, str(wav_filename))
|
||||
for wav_filename in self.raw.wav_filename
|
||||
]
|
||||
wav_frames = [
|
||||
int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
for wav_filepath in wav_filepaths
|
||||
]
|
||||
is_valid_raw_wav_frames = [
|
||||
self._is_wav_frame_valid(wav_frame, transcript)
|
||||
for wav_frame, transcript in zip(wav_frames, transcripts)
|
||||
]
|
||||
return pd.Series(is_valid_raw_wav_frames)
|
||||
|
||||
def _is_wav_frame_valid(self, wav_frame, transcript):
|
||||
is_wav_frame_valid = True
|
||||
if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
|
||||
if int(wav_frame / SAMPLE_RATE * 1000 / 10 / 2) < len(str(transcript)):
|
||||
is_wav_frame_valid = False
|
||||
elif wav_frame/SAMPLE_RATE > MAX_SECS:
|
||||
elif wav_frame / SAMPLE_RATE > MAX_SECS:
|
||||
is_wav_frame_valid = False
|
||||
return is_wav_frame_valid
|
||||
|
||||
@ -280,7 +331,14 @@ class GramVaaniDataSets:
|
||||
def _save(self, dataset):
|
||||
dataset_path = os.path.join(self.target_dir, dataset + ".csv")
|
||||
dataframe = getattr(self, dataset)
|
||||
dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
|
||||
dataframe.to_csv(
|
||||
dataset_path,
|
||||
index=False,
|
||||
encoding="utf-8",
|
||||
escapechar="\\",
|
||||
quoting=csv.QUOTE_MINIMAL,
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
"""Main entry point allowing external calls
|
||||
@ -304,4 +362,5 @@ def main(args):
|
||||
datasets.save()
|
||||
_logger.info("Finished GramVaani importer...")
|
||||
|
||||
|
||||
main(sys.argv[1:])
|
||||
|
@ -1,28 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
import sys
|
||||
|
||||
import pandas
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data
|
||||
LDC93S1_BASE = "LDC93S1"
|
||||
LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
|
||||
local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav")
|
||||
trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt")
|
||||
local_file = maybe_download(
|
||||
LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
|
||||
)
|
||||
trans_file = maybe_download(
|
||||
LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
|
||||
)
|
||||
with open(trans_file, "r") as fin:
|
||||
transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '')
|
||||
transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
|
||||
".", ""
|
||||
)
|
||||
|
||||
df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
|
||||
columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
df = pandas.DataFrame(
|
||||
data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
|
||||
columns=["wav_filename", "wav_filesize", "transcript"],
|
||||
)
|
||||
df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,33 +1,39 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import pandas
|
||||
import progressbar
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
|
||||
import pandas
|
||||
import progressbar
|
||||
from sox import Transformer
|
||||
from util.downloader import maybe_download
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data to data_dir
|
||||
print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir))
|
||||
print(
|
||||
"Downloading Librivox data set (55GB) into {} if not already present...".format(
|
||||
data_dir
|
||||
)
|
||||
)
|
||||
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
|
||||
TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
|
||||
TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
|
||||
TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz"
|
||||
TRAIN_CLEAN_100_URL = (
|
||||
"http://www.openslr.org/resources/12/train-clean-100.tar.gz"
|
||||
)
|
||||
TRAIN_CLEAN_360_URL = (
|
||||
"http://www.openslr.org/resources/12/train-clean-360.tar.gz"
|
||||
)
|
||||
TRAIN_OTHER_500_URL = (
|
||||
"http://www.openslr.org/resources/12/train-other-500.tar.gz"
|
||||
)
|
||||
|
||||
DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
|
||||
DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
|
||||
@ -35,12 +41,20 @@ def _download_and_preprocess_data(data_dir):
|
||||
TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
|
||||
TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"
|
||||
|
||||
def filename_of(x): return os.path.split(x)[1]
|
||||
train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL)
|
||||
def filename_of(x):
|
||||
return os.path.split(x)[1]
|
||||
|
||||
train_clean_100 = maybe_download(
|
||||
filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
|
||||
)
|
||||
bar.update(0)
|
||||
train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL)
|
||||
train_clean_360 = maybe_download(
|
||||
filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
|
||||
)
|
||||
bar.update(1)
|
||||
train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL)
|
||||
train_other_500 = maybe_download(
|
||||
filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
|
||||
)
|
||||
bar.update(2)
|
||||
|
||||
dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
|
||||
@ -48,9 +62,13 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
|
||||
bar.update(4)
|
||||
|
||||
test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL)
|
||||
test_clean = maybe_download(
|
||||
filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
|
||||
)
|
||||
bar.update(5)
|
||||
test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL)
|
||||
test_other = maybe_download(
|
||||
filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
|
||||
)
|
||||
bar.update(6)
|
||||
|
||||
# Conditionally extract LibriSpeech data
|
||||
@ -61,11 +79,17 @@ def _download_and_preprocess_data(data_dir):
|
||||
LIBRIVOX_DIR = "LibriSpeech"
|
||||
work_dir = os.path.join(data_dir, LIBRIVOX_DIR)
|
||||
|
||||
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100)
|
||||
_maybe_extract(
|
||||
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
|
||||
)
|
||||
bar.update(0)
|
||||
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360)
|
||||
_maybe_extract(
|
||||
data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
|
||||
)
|
||||
bar.update(1)
|
||||
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500)
|
||||
_maybe_extract(
|
||||
data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
|
||||
)
|
||||
bar.update(2)
|
||||
|
||||
_maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
|
||||
@ -91,28 +115,48 @@ def _download_and_preprocess_data(data_dir):
|
||||
# data_dir/LibriSpeech/split-wav/1-2-2.txt
|
||||
# ...
|
||||
print("Converting FLAC to WAV and splitting transcriptions...")
|
||||
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
|
||||
train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav")
|
||||
with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
|
||||
train_100 = _convert_audio_and_split_sentences(
|
||||
work_dir, "train-clean-100", "train-clean-100-wav"
|
||||
)
|
||||
bar.update(0)
|
||||
train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav")
|
||||
train_360 = _convert_audio_and_split_sentences(
|
||||
work_dir, "train-clean-360", "train-clean-360-wav"
|
||||
)
|
||||
bar.update(1)
|
||||
train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav")
|
||||
train_500 = _convert_audio_and_split_sentences(
|
||||
work_dir, "train-other-500", "train-other-500-wav"
|
||||
)
|
||||
bar.update(2)
|
||||
|
||||
dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav")
|
||||
dev_clean = _convert_audio_and_split_sentences(
|
||||
work_dir, "dev-clean", "dev-clean-wav"
|
||||
)
|
||||
bar.update(3)
|
||||
dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav")
|
||||
dev_other = _convert_audio_and_split_sentences(
|
||||
work_dir, "dev-other", "dev-other-wav"
|
||||
)
|
||||
bar.update(4)
|
||||
|
||||
test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav")
|
||||
test_clean = _convert_audio_and_split_sentences(
|
||||
work_dir, "test-clean", "test-clean-wav"
|
||||
)
|
||||
bar.update(5)
|
||||
test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav")
|
||||
test_other = _convert_audio_and_split_sentences(
|
||||
work_dir, "test-other", "test-other-wav"
|
||||
)
|
||||
bar.update(6)
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False)
|
||||
train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False)
|
||||
train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False)
|
||||
train_100.to_csv(
|
||||
os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
|
||||
)
|
||||
train_360.to_csv(
|
||||
os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
|
||||
)
|
||||
train_500.to_csv(
|
||||
os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
|
||||
)
|
||||
|
||||
dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
|
||||
dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
|
||||
@ -120,6 +164,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
|
||||
test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)
|
||||
|
||||
|
||||
def _maybe_extract(data_dir, extracted_data, archive):
|
||||
# If data_dir/extracted_data does not exist, extract archive in data_dir
|
||||
if not gfile.Exists(os.path.join(data_dir, extracted_data)):
|
||||
@ -127,6 +172,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
|
||||
tar.extractall(data_dir)
|
||||
tar.close()
|
||||
|
||||
|
||||
def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
source_dir = os.path.join(extracted_dir, data_set)
|
||||
target_dir = os.path.join(extracted_dir, dest_dir)
|
||||
@ -149,20 +195,22 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
# We also convert the corresponding FLACs to WAV in the same pass
|
||||
files = []
|
||||
for root, dirnames, filenames in os.walk(source_dir):
|
||||
for filename in fnmatch.filter(filenames, '*.trans.txt'):
|
||||
for filename in fnmatch.filter(filenames, "*.trans.txt"):
|
||||
trans_filename = os.path.join(root, filename)
|
||||
with codecs.open(trans_filename, "r", "utf-8") as fin:
|
||||
for line in fin:
|
||||
# Parse each segment line
|
||||
first_space = line.find(" ")
|
||||
seqid, transcript = line[:first_space], line[first_space+1:]
|
||||
seqid, transcript = line[:first_space], line[first_space + 1 :]
|
||||
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
transcript = transcript.lower().strip()
|
||||
|
||||
@ -177,7 +225,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
|
||||
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,44 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sox
|
||||
import zipfile
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
import zipfile
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
|
||||
ARCHIVE_DIR_NAME = 'lingua_libre'
|
||||
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip'
|
||||
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "lingua_libre"
|
||||
ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
|
||||
ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -46,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
|
||||
# Produce CSV files and convert ogg data to wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -58,57 +54,70 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
else:
|
||||
print('Found directory "%s" - not extracting it from archive.' % archive_path)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
ogg_filename = sample[0]
|
||||
# Storing wav files next to the ogg ones - just with a different suffix
|
||||
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
|
||||
wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
|
||||
_maybe_convert_wav(ogg_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter(sample[1])
|
||||
rows = []
|
||||
counter = get_counter()
|
||||
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
|
||||
target_csv_template = os.path.join(
|
||||
target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
|
||||
)
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
|
||||
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
|
||||
ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
|
||||
glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
record_file = record.replace(ogg_root_dir + os.path.sep, '')
|
||||
record_file = record.replace(ogg_root_dir + os.path.sep, "")
|
||||
if record_filter(record_file):
|
||||
samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
|
||||
samples.append(
|
||||
(
|
||||
os.path.join(ogg_root_dir, record_file),
|
||||
os.path.splitext(os.path.basename(record_file))[0],
|
||||
)
|
||||
)
|
||||
|
||||
counter = get_counter()
|
||||
num_samples = len(samples)
|
||||
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -139,7 +148,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
transcript = validate_label(item[2])
|
||||
if not transcript:
|
||||
continue
|
||||
wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav'))
|
||||
wav_filename = os.path.join(
|
||||
ogg_root_dir, item[0].replace(".ogg", ".wav")
|
||||
)
|
||||
i_mod = i % 10
|
||||
if i_mod == 0:
|
||||
writer = test_writer
|
||||
@ -147,38 +158,63 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def _maybe_convert_wav(ogg_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
transformer.build(ogg_filename, wav_filename)
|
||||
except sox.core.SoxError as ex:
|
||||
print('SoX processing error', ex, ogg_filename, wav_filename)
|
||||
print("SoX processing error", ex, ogg_filename, wav_filename)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
|
||||
parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
|
||||
parser.add_argument('--english-name', type=str, required=True, help='Enligh name of the language')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--qId", type=int, required=True, help="LinguaLibre language qId"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iso639-3", type=str, required=True, help="ISO639-3 language code"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--english-name", type=str, required=True, help="Enligh name of the language"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bogus-records",
|
||||
type=argparse.FileType("r"),
|
||||
required=False,
|
||||
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
@ -191,15 +227,17 @@ if __name__ == "__main__":
|
||||
|
||||
def record_filter(path):
|
||||
if any(regex.match(path) for regex in bogus_regexes):
|
||||
print('Reject', path)
|
||||
print("Reject", path)
|
||||
return False
|
||||
return True
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
@ -208,6 +246,14 @@ if __name__ == "__main__":
|
||||
label = None
|
||||
return label
|
||||
|
||||
ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
|
||||
ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
|
||||
ARCHIVE_NAME = ARCHIVE_NAME.format(
|
||||
qId=CLI_ARGS.qId,
|
||||
iso639_3=CLI_ARGS.iso639_3,
|
||||
language_English_name=CLI_ARGS.english_name,
|
||||
)
|
||||
ARCHIVE_URL = ARCHIVE_URL.format(
|
||||
qId=CLI_ARGS.qId,
|
||||
iso639_3=CLI_ARGS.iso639_3,
|
||||
language_English_name=CLI_ARGS.english_name,
|
||||
)
|
||||
_download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)
|
||||
|
@ -1,43 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
# pylint: disable=invalid-name
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import os
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
import tarfile
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
import unicodedata
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
import progressbar
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
|
||||
ARCHIVE_DIR_NAME = '{language}'
|
||||
ARCHIVE_NAME = '{language}.tgz'
|
||||
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "{language}"
|
||||
ARCHIVE_NAME = "{language}.tgz"
|
||||
ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -48,8 +42,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -65,9 +59,13 @@ def one_sample(sample):
|
||||
wav_filename = sample[0]
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter(sample[1])
|
||||
counter = get_counter()
|
||||
rows = []
|
||||
@ -75,27 +73,30 @@ def one_sample(sample):
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
print("conversion failure", wav_filename)
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
|
||||
target_csv_template = os.path.join(
|
||||
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
|
||||
)
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
|
||||
@ -103,14 +104,16 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(wav_root_dir, '**/metadata.csv')
|
||||
glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
if any(map(lambda sk: sk in record, SKIP_LIST)): # pylint: disable=cell-var-from-loop
|
||||
if any(
|
||||
map(lambda sk: sk in record, SKIP_LIST)
|
||||
): # pylint: disable=cell-var-from-loop
|
||||
continue
|
||||
with open(record, 'r') as rec:
|
||||
with open(record, "r") as rec:
|
||||
for re in rec.readlines():
|
||||
re = re.strip().split('|')
|
||||
audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav')
|
||||
re = re.strip().split("|")
|
||||
audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
|
||||
transcript = re[2]
|
||||
samples.append((audio, transcript))
|
||||
|
||||
@ -129,9 +132,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -151,39 +154,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=os.path.relpath(wav_filename, extracted_dir),
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=os.path.relpath(wav_filename, extracted_dir),
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated')
|
||||
parser.add_argument('--language', required=True, type=str, help='Dataset language to use')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skiplist",
|
||||
type=str,
|
||||
default="",
|
||||
help="Directories / books to skip, comma separated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language", required=True, type=str, help="Dataset language to use"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
|
||||
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
|
||||
validate_label = get_validate_label(CLI_ARGS)
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
|
@ -1,30 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import pandas
|
||||
import os
|
||||
import tarfile
|
||||
import wave
|
||||
|
||||
import pandas
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
|
||||
def is_file_truncated(wav_filename, wav_filesize):
|
||||
with wave.open(wav_filename, mode='rb') as fin:
|
||||
with wave.open(wav_filename, mode="rb") as fin:
|
||||
assert fin.getframerate() == 16000
|
||||
assert fin.getsampwidth() == 2
|
||||
assert fin.getnchannels() == 1
|
||||
@ -37,8 +31,13 @@ def is_file_truncated(wav_filename, wav_filesize):
|
||||
|
||||
def preprocess_data(folder_with_archives, target_dir):
|
||||
# First extract subset archives
|
||||
for subset in ('train', 'dev', 'test'):
|
||||
extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir)
|
||||
for subset in ("train", "dev", "test"):
|
||||
extract(
|
||||
os.path.join(
|
||||
folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
|
||||
),
|
||||
target_dir,
|
||||
)
|
||||
|
||||
# Folder structure is now:
|
||||
# - magicdata_{train,dev,test}.tar.gz
|
||||
@ -54,58 +53,73 @@ def preprocess_data(folder_with_archives, target_dir):
|
||||
# name, one containing the speaker ID, and one containing the transcription
|
||||
|
||||
def load_set(set_path):
|
||||
transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0)
|
||||
glob_path = os.path.join(set_path, '*', '*.wav')
|
||||
transcripts = pandas.read_csv(
|
||||
os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
|
||||
)
|
||||
glob_path = os.path.join(set_path, "*", "*.wav")
|
||||
set_files = []
|
||||
for wav in glob.glob(glob_path):
|
||||
try:
|
||||
wav_filename = wav
|
||||
wav_filesize = os.path.getsize(wav)
|
||||
transcript_key = os.path.basename(wav)
|
||||
transcript = transcripts.loc[transcript_key, 'Transcription']
|
||||
transcript = transcripts.loc[transcript_key, "Transcription"]
|
||||
|
||||
# Some files in this dataset are truncated, the header duration
|
||||
# doesn't match the file size. This causes errors at training
|
||||
# time, so check here if things are fine before including a file
|
||||
if is_file_truncated(wav_filename, wav_filesize):
|
||||
print('Warning: File {} is corrupted, header duration does '
|
||||
'not match file size. Ignoring.'.format(wav_filename))
|
||||
print(
|
||||
"Warning: File {} is corrupted, header duration does "
|
||||
"not match file size. Ignoring.".format(wav_filename)
|
||||
)
|
||||
continue
|
||||
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
except KeyError:
|
||||
print('Warning: Missing transcript for WAV file {}.'.format(wav))
|
||||
print("Warning: Missing transcript for WAV file {}.".format(wav))
|
||||
return set_files
|
||||
|
||||
for subset in ('train', 'dev', 'test'):
|
||||
print('Loading {} set samples...'.format(subset))
|
||||
for subset in ("train", "dev", "test"):
|
||||
print("Loading {} set samples...".format(subset))
|
||||
subset_files = load_set(os.path.join(target_dir, subset))
|
||||
df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)
|
||||
|
||||
# Trim train set to under 10s
|
||||
if subset == 'train':
|
||||
durations = (df['wav_filesize'] - 44) / 16000 / 2
|
||||
if subset == "train":
|
||||
durations = (df["wav_filesize"] - 44) / 16000 / 2
|
||||
df = df[durations <= 10.0]
|
||||
print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
|
||||
|
||||
with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]')
|
||||
df = df[~with_noise]
|
||||
print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise)))
|
||||
print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
|
||||
|
||||
dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset))
|
||||
print('Saving {} set into {}...'.format(subset, dest_csv))
|
||||
with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
|
||||
df = df[~with_noise]
|
||||
print(
|
||||
"Trimming {} samples with noise ([FIL] or [SPK])".format(
|
||||
sum(with_noise)
|
||||
)
|
||||
)
|
||||
|
||||
dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
|
||||
print("Saving {} set into {}...".format(subset, dest_csv))
|
||||
df.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# https://openslr.org/68/
|
||||
parser = get_importers_parser(description='Import MAGICDATA corpus')
|
||||
parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
|
||||
parser = get_importers_parser(description="Import MAGICDATA corpus")
|
||||
parser.add_argument(
|
||||
"folder_with_archives",
|
||||
help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
params.target_dir = os.path.join(params.folder_with_archives, 'magicdata')
|
||||
params.target_dir = os.path.join(params.folder_with_archives, "magicdata")
|
||||
|
||||
preprocess_data(params.folder_with_archives, params.target_dir)
|
||||
|
||||
|
@ -1,25 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
import pandas
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def extract(archive_path, target_dir):
|
||||
print('Extracting {} into {}...'.format(archive_path, target_dir))
|
||||
print("Extracting {} into {}...".format(archive_path, target_dir))
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
|
||||
@ -27,7 +21,7 @@ def extract(archive_path, target_dir):
|
||||
def preprocess_data(tgz_file, target_dir):
|
||||
# First extract main archive and sub-archives
|
||||
extract(tgz_file, target_dir)
|
||||
main_folder = os.path.join(target_dir, 'primewords_md_2018_set1')
|
||||
main_folder = os.path.join(target_dir, "primewords_md_2018_set1")
|
||||
|
||||
# Folder structure is now:
|
||||
# - primewords_md_2018_set1/
|
||||
@ -35,14 +29,11 @@ def preprocess_data(tgz_file, target_dir):
|
||||
# - [0-f]/[00-0f]/*.wav
|
||||
# - set1_transcript.json
|
||||
|
||||
transcripts_path = os.path.join(main_folder, 'set1_transcript.json')
|
||||
transcripts_path = os.path.join(main_folder, "set1_transcript.json")
|
||||
with open(transcripts_path) as fin:
|
||||
transcripts = json.load(fin)
|
||||
|
||||
transcripts = {
|
||||
entry['file']: entry['text']
|
||||
for entry in transcripts
|
||||
}
|
||||
transcripts = {entry["file"]: entry["text"] for entry in transcripts}
|
||||
|
||||
def load_set(glob_path):
|
||||
set_files = []
|
||||
@ -54,13 +45,13 @@ def preprocess_data(tgz_file, target_dir):
|
||||
transcript = transcripts[transcript_key]
|
||||
set_files.append((wav_filename, wav_filesize, transcript))
|
||||
except KeyError:
|
||||
print('Warning: Missing transcript for WAV file {}.'.format(wav))
|
||||
print("Warning: Missing transcript for WAV file {}.".format(wav))
|
||||
return set_files
|
||||
|
||||
# Load all files, then deterministically split into train/dev/test sets
|
||||
all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav'))
|
||||
all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
|
||||
df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
|
||||
df.sort_values(by='wav_filename', inplace=True)
|
||||
df.sort_values(by="wav_filename", inplace=True)
|
||||
|
||||
indices = np.arange(0, len(df))
|
||||
np.random.seed(12345)
|
||||
@ -73,29 +64,33 @@ def preprocess_data(tgz_file, target_dir):
|
||||
train_indices = indices[:-10000]
|
||||
|
||||
train_files = df.iloc[train_indices]
|
||||
durations = (train_files['wav_filesize'] - 44) / 16000 / 2
|
||||
durations = (train_files["wav_filesize"] - 44) / 16000 / 2
|
||||
train_files = train_files[durations <= 15.0]
|
||||
print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum()))
|
||||
dest_csv = os.path.join(target_dir, 'primewords_train.csv')
|
||||
print('Saving train set into {}...'.format(dest_csv))
|
||||
print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
|
||||
dest_csv = os.path.join(target_dir, "primewords_train.csv")
|
||||
print("Saving train set into {}...".format(dest_csv))
|
||||
train_files.to_csv(dest_csv, index=False)
|
||||
|
||||
dev_files = df.iloc[dev_indices]
|
||||
dest_csv = os.path.join(target_dir, 'primewords_dev.csv')
|
||||
print('Saving dev set into {}...'.format(dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "primewords_dev.csv")
|
||||
print("Saving dev set into {}...".format(dest_csv))
|
||||
dev_files.to_csv(dest_csv, index=False)
|
||||
|
||||
test_files = df.iloc[test_indices]
|
||||
dest_csv = os.path.join(target_dir, 'primewords_test.csv')
|
||||
print('Saving test set into {}...'.format(dest_csv))
|
||||
dest_csv = os.path.join(target_dir, "primewords_test.csv")
|
||||
print("Saving test set into {}...".format(dest_csv))
|
||||
test_files.to_csv(dest_csv, index=False)
|
||||
|
||||
|
||||
def main():
|
||||
# https://www.openslr.org/47/
|
||||
parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
|
||||
parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
|
||||
parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
|
||||
parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
|
||||
parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default="",
|
||||
help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
|
||||
)
|
||||
params = parser.parse_args()
|
||||
|
||||
if not params.target_dir:
|
||||
|
@ -1,45 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sox
|
||||
import zipfile
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
import tarfile
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
import unicodedata
|
||||
import zipfile
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
from util.helpers import secs_to_hours
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
|
||||
ARCHIVE_DIR_NAME = 'African_Accented_French'
|
||||
ARCHIVE_NAME = 'African_Accented_French.tar.gz'
|
||||
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "African_Accented_French"
|
||||
ARCHIVE_NAME = "African_Accented_French.tar.gz"
|
||||
ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -47,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
|
||||
# Produce CSV files
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -60,81 +55,89 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
else:
|
||||
print('Found directory "%s" - not extracting it from archive.' % archive_path)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
wav_filename = sample[0]
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = label_filter(sample[1])
|
||||
counter = get_counter()
|
||||
rows = []
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
|
||||
target_csv_template = os.path.join(
|
||||
target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
|
||||
)
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
|
||||
wav_root_dir = os.path.join(extracted_dir)
|
||||
|
||||
all_files = [
|
||||
'transcripts/train/yaounde/fn_text.txt',
|
||||
'transcripts/train/ca16_conv/transcripts.txt',
|
||||
'transcripts/train/ca16_read/conditioned.txt',
|
||||
'transcripts/dev/niger_west_african_fr/transcripts.txt',
|
||||
'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv',
|
||||
'transcripts/devtest/ca16_read/conditioned.txt',
|
||||
'transcripts/test/ca16/prompts.txt',
|
||||
"transcripts/train/yaounde/fn_text.txt",
|
||||
"transcripts/train/ca16_conv/transcripts.txt",
|
||||
"transcripts/train/ca16_read/conditioned.txt",
|
||||
"transcripts/dev/niger_west_african_fr/transcripts.txt",
|
||||
"speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
|
||||
"transcripts/devtest/ca16_read/conditioned.txt",
|
||||
"transcripts/test/ca16/prompts.txt",
|
||||
]
|
||||
|
||||
transcripts = {}
|
||||
for tr in all_files:
|
||||
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source:
|
||||
with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
|
||||
for line in tr_source.readlines():
|
||||
line = line.strip()
|
||||
|
||||
if '.tsv' in tr:
|
||||
sep = ' '
|
||||
if ".tsv" in tr:
|
||||
sep = " "
|
||||
else:
|
||||
sep = ' '
|
||||
sep = " "
|
||||
|
||||
audio = os.path.basename(line.split(sep)[0])
|
||||
|
||||
if not ('.wav' in audio):
|
||||
if '.tdf' in audio:
|
||||
audio = audio.replace('.tdf', '.wav')
|
||||
if not (".wav" in audio):
|
||||
if ".tdf" in audio:
|
||||
audio = audio.replace(".tdf", ".wav")
|
||||
else:
|
||||
audio += '.wav'
|
||||
audio += ".wav"
|
||||
|
||||
transcript = ' '.join(line.split(sep)[1:])
|
||||
transcript = " ".join(line.split(sep)[1:])
|
||||
transcripts[audio] = transcript
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(wav_root_dir, '**/*.wav')
|
||||
glob_dir = os.path.join(wav_root_dir, "**/*.wav")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
record_file = os.path.basename(record)
|
||||
if record_file in transcripts:
|
||||
@ -156,9 +159,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -178,25 +181,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
@ -204,9 +220,11 @@ if __name__ == "__main__":
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
|
@ -1,44 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
|
||||
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
|
||||
# ./data/swb/swb1_LDC97S62.tgz
|
||||
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import pandas
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import codecs
|
||||
import tarfile
|
||||
import requests
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
import pandas
|
||||
import requests
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
|
||||
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
|
||||
ARCHIVE_URL = 'http://www.openslr.org/resources/5/'
|
||||
ARCHIVE_DIR_NAME = 'LDC97S62'
|
||||
LDC_DATASET = 'swb1_LDC97S62.tgz'
|
||||
ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
|
||||
ARCHIVE_URL = "http://www.openslr.org/resources/5/"
|
||||
ARCHIVE_DIR_NAME = "LDC97S62"
|
||||
LDC_DATASET = "swb1_LDC97S62.tgz"
|
||||
|
||||
|
||||
def download_file(folder, url):
|
||||
# https://stackoverflow.com/a/16696317/738515
|
||||
local_filename = url.split('/')[-1]
|
||||
local_filename = url.split("/")[-1]
|
||||
full_filename = os.path.join(folder, local_filename)
|
||||
r = requests.get(url, stream=True)
|
||||
with open(full_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
with open(full_filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
return full_filename
|
||||
|
||||
@ -46,7 +40,7 @@ def download_file(folder, url):
|
||||
def maybe_download(archive_url, target_dir, ldc_dataset):
|
||||
# If archive file does not exist, download it...
|
||||
archive_path = os.path.join(target_dir, ldc_dataset)
|
||||
ldc_path = archive_url+ldc_dataset
|
||||
ldc_path = archive_url + ldc_dataset
|
||||
if not os.path.exists(target_dir):
|
||||
print('No path "%s" - creating ...' % target_dir)
|
||||
makedirs(target_dir)
|
||||
@ -65,17 +59,23 @@ def _download_and_preprocess_data(data_dir):
|
||||
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
|
||||
|
||||
# Check swb1_LDC97S62.tgz then extract
|
||||
assert(os.path.isfile(archive_path))
|
||||
assert os.path.isfile(archive_path)
|
||||
_extract(target_dir, archive_path)
|
||||
|
||||
|
||||
# Transcripts
|
||||
transcripts_path = maybe_download(ARCHIVE_URL, target_dir, ARCHIVE_NAME)
|
||||
_extract(target_dir, transcripts_path)
|
||||
|
||||
# Check swb1_d1/2/3/4/swb_ms98_transcriptions
|
||||
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"]
|
||||
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders]))
|
||||
|
||||
expected_folders = [
|
||||
"swb1_d1",
|
||||
"swb1_d2",
|
||||
"swb1_d3",
|
||||
"swb1_d4",
|
||||
"swb_ms98_transcriptions",
|
||||
]
|
||||
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
|
||||
|
||||
# Conditionally convert swb sph data to wav
|
||||
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
|
||||
_maybe_convert_wav(target_dir, "swb1_d2", "swb1_d2-wav")
|
||||
@ -83,13 +83,21 @@ def _download_and_preprocess_data(data_dir):
|
||||
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
|
||||
|
||||
# Conditionally split wav data
|
||||
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav")
|
||||
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav")
|
||||
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav")
|
||||
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav")
|
||||
|
||||
d1 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
|
||||
)
|
||||
d2 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
|
||||
)
|
||||
d3 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
|
||||
)
|
||||
d4 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
|
||||
)
|
||||
|
||||
swb_files = d1.append(d2).append(d3).append(d4)
|
||||
|
||||
|
||||
train_files, dev_files, test_files = _split_sets(swb_files)
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
@ -97,7 +105,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(os.path.join(target_dir, "swb-dev.csv"), index=False)
|
||||
test_files.to_csv(os.path.join(target_dir, "swb-test.csv"), index=False)
|
||||
|
||||
|
||||
|
||||
def _extract(target_dir, archive_path):
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
@ -118,25 +126,46 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
|
||||
# Loop over sph files in source_dir and convert each to 16-bit PCM wav
|
||||
for root, dirnames, filenames in os.walk(source_dir):
|
||||
for filename in fnmatch.filter(filenames, "*.sph"):
|
||||
for channel in ['1', '2']:
|
||||
for channel in ["1", "2"]:
|
||||
sph_file = os.path.join(root, filename)
|
||||
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav"
|
||||
wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(target_dir, wav_filename)
|
||||
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav"
|
||||
temp_wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ "-temp.wav"
|
||||
)
|
||||
temp_wav_file = os.path.join(target_dir, temp_wav_filename)
|
||||
print("converting {} to {}".format(sph_file, temp_wav_file))
|
||||
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file])
|
||||
subprocess.check_call(
|
||||
[
|
||||
"sph2pipe",
|
||||
"-c",
|
||||
channel,
|
||||
"-p",
|
||||
"-f",
|
||||
"rif",
|
||||
sph_file,
|
||||
temp_wav_file,
|
||||
]
|
||||
)
|
||||
print("upsampling {} to {}".format(temp_wav_file, wav_file))
|
||||
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
|
||||
soundfile.write(wav_file, audioData, frameRate, "PCM_16")
|
||||
os.remove(temp_wav_file)
|
||||
|
||||
|
||||
|
||||
def _parse_transcriptions(trans_file):
|
||||
segments = []
|
||||
with codecs.open(trans_file, "r", "utf-8") as fin:
|
||||
for line in fin:
|
||||
if line.startswith("#") or len(line) <= 1:
|
||||
if line.startswith("#") or len(line) <= 1:
|
||||
continue
|
||||
|
||||
tokens = line.split()
|
||||
@ -150,15 +179,19 @@ def _parse_transcriptions(trans_file):
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
segments.append({
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"transcript": transcript,
|
||||
})
|
||||
segments.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"transcript": transcript,
|
||||
}
|
||||
)
|
||||
return segments
|
||||
|
||||
|
||||
@ -183,8 +216,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
|
||||
segments = _parse_transcriptions(trans_file)
|
||||
|
||||
# Open wav corresponding to transcription file
|
||||
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A']
|
||||
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav"
|
||||
channel = ("2", "1")[
|
||||
(os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
|
||||
]
|
||||
wav_filename = (
|
||||
"sw0"
|
||||
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(source_dir, wav_filename)
|
||||
|
||||
print("splitting {} according to {}".format(wav_file, trans_file))
|
||||
@ -200,26 +241,39 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
|
||||
# Create wav segment filename
|
||||
start_time = segment["start_time"]
|
||||
stop_time = segment["stop_time"]
|
||||
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(
|
||||
start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
os.path.splitext(os.path.basename(trans_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
if _is_wav_too_short(new_wav_filename):
|
||||
continue
|
||||
continue
|
||||
new_wav_file = os.path.join(target_dir, new_wav_filename)
|
||||
|
||||
_split_wav(origAudio, start_time, stop_time, new_wav_file)
|
||||
|
||||
new_wav_filesize = os.path.getsize(new_wav_file)
|
||||
transcript = segment["transcript"]
|
||||
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
|
||||
files.append(
|
||||
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
|
||||
)
|
||||
|
||||
# Close origAudio
|
||||
origAudio.close()
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
def _is_wav_too_short(wav_filename):
|
||||
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav']
|
||||
short_wav_filenames = [
|
||||
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
|
||||
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
|
||||
]
|
||||
return wav_filename in short_wav_filenames
|
||||
|
||||
|
||||
@ -234,7 +288,7 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
chunkAudio.writeframes(chunkData)
|
||||
chunkAudio.close()
|
||||
|
||||
|
||||
|
||||
def _split_sets(filelist):
|
||||
# We initially split the entire set into 80% train and 20% test, then
|
||||
# split the train set into 80% train and 20% validation.
|
||||
@ -248,10 +302,24 @@ def _split_sets(filelist):
|
||||
test_beg = dev_end
|
||||
test_end = len(filelist)
|
||||
|
||||
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end])
|
||||
return (
|
||||
filelist[train_beg:train_end],
|
||||
filelist[dev_beg:dev_end],
|
||||
filelist[test_beg:test_end],
|
||||
)
|
||||
|
||||
|
||||
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0):
|
||||
def _read_data_set(
|
||||
filelist,
|
||||
thread_count,
|
||||
batch_size,
|
||||
numcep,
|
||||
numcontext,
|
||||
stride=1,
|
||||
offset=0,
|
||||
next_index=lambda i: i + 1,
|
||||
limit=0,
|
||||
):
|
||||
# Optionally apply dataset size limit
|
||||
if limit > 0:
|
||||
filelist = filelist.iloc[:limit]
|
||||
@ -259,7 +327,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
|
||||
filelist = filelist[offset::stride]
|
||||
|
||||
# Return DataSet
|
||||
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index)
|
||||
return DataSet(
|
||||
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,70 +1,76 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
|
||||
Use "python3 import_swc.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
"""
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import re
|
||||
import csv
|
||||
import sox
|
||||
import wave
|
||||
import shutil
|
||||
import random
|
||||
import tarfile
|
||||
import argparse
|
||||
import progressbar
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from glob import glob
|
||||
from collections import Counter
|
||||
from glob import glob
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||
SWC_ARCHIVE = "SWC_{language}.tar"
|
||||
LANGUAGES = ['dutch', 'english', 'german']
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker']
|
||||
LANGUAGES = ["dutch", "english", "german"]
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
|
||||
CHANNELS = 1
|
||||
SAMPLE_RATE = 16000
|
||||
UNKNOWN = '<unknown>'
|
||||
AUDIO_PATTERN = 'audio*.ogg'
|
||||
WAV_NAME = 'audio.wav'
|
||||
ALIGNED_NAME = 'aligned.swc'
|
||||
UNKNOWN = "<unknown>"
|
||||
AUDIO_PATTERN = "audio*.ogg"
|
||||
WAV_NAME = "audio.wav"
|
||||
ALIGNED_NAME = "aligned.swc"
|
||||
|
||||
SUBSTITUTIONS = {
|
||||
'german': [
|
||||
(re.compile(r'\$'), 'dollar'),
|
||||
(re.compile(r'€'), 'euro'),
|
||||
(re.compile(r'£'), 'pfund'),
|
||||
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
|
||||
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
|
||||
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
|
||||
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
|
||||
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
|
||||
(re.compile(r'punkt null null null punkt null null null'), 'millionen'),
|
||||
(re.compile(r'eins punkt null null null'), 'ein tausend'),
|
||||
(re.compile(r'punkt null null null'), 'tausend'),
|
||||
(re.compile(r'punkt null'), None)
|
||||
"german": [
|
||||
(re.compile(r"\$"), "dollar"),
|
||||
(re.compile(r"€"), "euro"),
|
||||
(re.compile(r"£"), "pfund"),
|
||||
(
|
||||
re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
|
||||
r"\1zehnhundert \2er ",
|
||||
),
|
||||
(re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
|
||||
(
|
||||
re.compile(
|
||||
r"eins punkt null null null punkt null null null punkt null null null"
|
||||
),
|
||||
"eine milliarde",
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"punkt null null null punkt null null null punkt null null null"
|
||||
),
|
||||
"milliarden",
|
||||
),
|
||||
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
|
||||
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
|
||||
(re.compile(r"eins punkt null null null"), "ein tausend"),
|
||||
(re.compile(r"punkt null null null"), "tausend"),
|
||||
(re.compile(r"punkt null"), None),
|
||||
]
|
||||
}
|
||||
|
||||
DONT_NORMALIZE = {
|
||||
'german': 'ÄÖÜäöüß'
|
||||
}
|
||||
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
|
||||
|
||||
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
|
||||
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
|
||||
|
||||
|
||||
class Sample:
|
||||
@ -98,11 +104,14 @@ def get_sample_size(population_size):
|
||||
margin_of_error = 0.01
|
||||
fraction_picking = 0.50
|
||||
z_score = 2.58 # Corresponds to confidence level 99%
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2
|
||||
)
|
||||
sample_size = 0
|
||||
for train_size in range(population_size, 0, -1):
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
|
||||
(margin_of_error ** 2 * train_size)
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2 * train_size
|
||||
)
|
||||
sample_size = int(numerator / denominator)
|
||||
if 2 * sample_size + train_size <= population_size:
|
||||
break
|
||||
@ -111,14 +120,16 @@ def get_sample_size(population_size):
|
||||
|
||||
def maybe_download_language(language):
|
||||
lang_upper = language[0].upper() + language[1:]
|
||||
return maybe_download(SWC_ARCHIVE.format(language=lang_upper),
|
||||
CLI_ARGS.base_dir,
|
||||
SWC_URL.format(language=lang_upper))
|
||||
return maybe_download(
|
||||
SWC_ARCHIVE.format(language=lang_upper),
|
||||
CLI_ARGS.base_dir,
|
||||
SWC_URL.format(language=lang_upper),
|
||||
)
|
||||
|
||||
|
||||
def maybe_extract(data_dir, extracted_data, archive):
|
||||
extracted = path.join(data_dir, extracted_data)
|
||||
if path.isdir(extracted):
|
||||
extracted = os.path.join(data_dir, extracted_data)
|
||||
if os.path.isdir(extracted):
|
||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||
else:
|
||||
print('Extracting "{}"...'.format(archive))
|
||||
@ -133,29 +144,29 @@ def maybe_extract(data_dir, extracted_data, archive):
|
||||
def ignored(node):
|
||||
if node is None:
|
||||
return False
|
||||
if node.tag == 'ignored':
|
||||
if node.tag == "ignored":
|
||||
return True
|
||||
return ignored(node.find('..'))
|
||||
return ignored(node.find(".."))
|
||||
|
||||
|
||||
def read_token(token):
|
||||
texts, start, end = [], None, None
|
||||
notes = token.findall('n')
|
||||
notes = token.findall("n")
|
||||
if len(notes) > 0:
|
||||
for note in notes:
|
||||
attributes = note.attrib
|
||||
if start is None and 'start' in attributes:
|
||||
start = int(attributes['start'])
|
||||
if 'end' in attributes:
|
||||
token_end = int(attributes['end'])
|
||||
if start is None and "start" in attributes:
|
||||
start = int(attributes["start"])
|
||||
if "end" in attributes:
|
||||
token_end = int(attributes["end"])
|
||||
if end is None or token_end > end:
|
||||
end = token_end
|
||||
if 'pronunciation' in attributes:
|
||||
t = attributes['pronunciation']
|
||||
if "pronunciation" in attributes:
|
||||
t = attributes["pronunciation"]
|
||||
texts.append(t)
|
||||
elif 'text' in token.attrib:
|
||||
texts.append(token.attrib['text'])
|
||||
return start, end, ' '.join(texts)
|
||||
elif "text" in token.attrib:
|
||||
texts.append(token.attrib["text"])
|
||||
return start, end, " ".join(texts)
|
||||
|
||||
|
||||
def in_alphabet(alphabet, c):
|
||||
@ -163,10 +174,12 @@ def in_alphabet(alphabet, c):
|
||||
|
||||
|
||||
ALPHABETS = {}
|
||||
|
||||
|
||||
def get_alphabet(language):
|
||||
if language in ALPHABETS:
|
||||
return ALPHABETS[language]
|
||||
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
|
||||
alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
|
||||
alphabet = Alphabet(alphabet_path) if alphabet_path else None
|
||||
ALPHABETS[language] = alphabet
|
||||
return alphabet
|
||||
@ -176,27 +189,35 @@ def label_filter(label, language):
|
||||
label = label.translate(PRE_FILTER)
|
||||
label = validate_label(label)
|
||||
if label is None:
|
||||
return None, 'validation'
|
||||
return None, "validation"
|
||||
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
|
||||
for pattern, replacement in substitutions:
|
||||
if replacement is None:
|
||||
if pattern.match(label):
|
||||
return None, 'substitution rule'
|
||||
return None, "substitution rule"
|
||||
else:
|
||||
label = pattern.sub(replacement, label)
|
||||
chars = []
|
||||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
|
||||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
|
||||
alphabet = get_alphabet(language)
|
||||
for c in label:
|
||||
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in dont_normalize
|
||||
and not in_alphabet(alphabet, c)
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
for sc in c:
|
||||
if not in_alphabet(alphabet, sc):
|
||||
return None, 'illegal character'
|
||||
return None, "illegal character"
|
||||
chars.append(sc)
|
||||
label = ''.join(chars)
|
||||
label = "".join(chars)
|
||||
label = validate_label(label)
|
||||
return label, 'validation' if label is None else None
|
||||
return label, "validation" if label is None else None
|
||||
|
||||
|
||||
def collect_samples(base_dir, language):
|
||||
@ -207,7 +228,9 @@ def collect_samples(base_dir, language):
|
||||
samples = []
|
||||
reasons = Counter()
|
||||
|
||||
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'):
|
||||
def add_sample(
|
||||
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
|
||||
):
|
||||
if p_start is not None and p_end is not None and p_text is not None:
|
||||
duration = p_end - p_start
|
||||
text, filter_reason = label_filter(p_text, language)
|
||||
@ -217,53 +240,67 @@ def collect_samples(base_dir, language):
|
||||
p_reason = filter_reason
|
||||
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
|
||||
skip = True
|
||||
p_reason = 'unknown speaker'
|
||||
p_reason = "unknown speaker"
|
||||
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
|
||||
skip = True
|
||||
p_reason = 'unknown article'
|
||||
p_reason = "unknown article"
|
||||
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
|
||||
skip = True
|
||||
p_reason = 'exceeded duration'
|
||||
p_reason = "exceeded duration"
|
||||
elif int(duration / 30) < len(text):
|
||||
skip = True
|
||||
p_reason = 'too short to decode'
|
||||
p_reason = "too short to decode"
|
||||
elif duration / len(text) < 10:
|
||||
skip = True
|
||||
p_reason = 'length duration ratio'
|
||||
p_reason = "length duration ratio"
|
||||
if skip:
|
||||
reasons[p_reason] += 1
|
||||
else:
|
||||
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker))
|
||||
samples.append(
|
||||
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
|
||||
)
|
||||
elif p_start is None or p_end is None:
|
||||
reasons['missing timestamps'] += 1
|
||||
reasons["missing timestamps"] += 1
|
||||
else:
|
||||
reasons['missing text'] += 1
|
||||
reasons["missing text"] += 1
|
||||
|
||||
print('Collecting samples...')
|
||||
print("Collecting samples...")
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
for root in bar(roots):
|
||||
wav_path = path.join(root, WAV_NAME)
|
||||
wav_path = os.path.join(root, WAV_NAME)
|
||||
aligned = ET.parse(path.join(root, ALIGNED_NAME))
|
||||
article = UNKNOWN
|
||||
speaker = UNKNOWN
|
||||
for prop in aligned.iter('prop'):
|
||||
for prop in aligned.iter("prop"):
|
||||
attributes = prop.attrib
|
||||
if 'key' in attributes and 'value' in attributes:
|
||||
if attributes['key'] == 'DC.identifier':
|
||||
article = attributes['value']
|
||||
elif attributes['key'] == 'reader.name':
|
||||
speaker = attributes['value']
|
||||
for sentence in aligned.iter('s'):
|
||||
if "key" in attributes and "value" in attributes:
|
||||
if attributes["key"] == "DC.identifier":
|
||||
article = attributes["value"]
|
||||
elif attributes["key"] == "reader.name":
|
||||
speaker = attributes["value"]
|
||||
for sentence in aligned.iter("s"):
|
||||
if ignored(sentence):
|
||||
continue
|
||||
split = False
|
||||
tokens = list(map(read_token, sentence.findall('t')))
|
||||
tokens = list(map(read_token, sentence.findall("t")))
|
||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||
for token_start, token_end, token_text in tokens:
|
||||
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='has numbers')
|
||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="has numbers",
|
||||
)
|
||||
sample_start, sample_end, token_texts, sample_texts = (
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
continue
|
||||
if sample_start is None:
|
||||
sample_start = token_start
|
||||
@ -271,20 +308,37 @@ def collect_samples(base_dir, language):
|
||||
continue
|
||||
token_texts.append(token_text)
|
||||
if token_end is not None:
|
||||
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='split')
|
||||
if (
|
||||
token_start != sample_start
|
||||
and token_end - sample_start > CLI_ARGS.max_duration > 0
|
||||
):
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="split",
|
||||
)
|
||||
sample_start = sample_end
|
||||
sample_texts = []
|
||||
split = True
|
||||
sample_end = token_end
|
||||
sample_texts.extend(token_texts)
|
||||
token_texts = []
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='split' if split else 'complete')
|
||||
print('Skipped samples:')
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="split" if split else "complete",
|
||||
)
|
||||
print("Skipped samples:")
|
||||
for reason, n in reasons.most_common():
|
||||
print(' - {}: {}'.format(reason, n))
|
||||
print(" - {}: {}".format(reason, n))
|
||||
return samples
|
||||
|
||||
|
||||
@ -294,8 +348,8 @@ def maybe_convert_one_to_wav(entry):
|
||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||
combiner = sox.Combiner()
|
||||
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||
output_wav = path.join(root, WAV_NAME)
|
||||
if path.isfile(output_wav):
|
||||
output_wav = os.path.join(root, WAV_NAME)
|
||||
if os.path.isfile(output_wav):
|
||||
return
|
||||
files = sorted(glob(path.join(root, AUDIO_PATTERN)))
|
||||
try:
|
||||
@ -304,18 +358,18 @@ def maybe_convert_one_to_wav(entry):
|
||||
elif len(files) > 1:
|
||||
wav_files = []
|
||||
for i, file in enumerate(files):
|
||||
wav_path = path.join(root, 'audio{}.wav'.format(i))
|
||||
wav_path = os.path.join(root, "audio{}.wav".format(i))
|
||||
transformer.build(file, wav_path)
|
||||
wav_files.append(wav_path)
|
||||
combiner.set_input_format(file_type=['wav'] * len(wav_files))
|
||||
combiner.build(wav_files, output_wav, 'concatenate')
|
||||
combiner.set_input_format(file_type=["wav"] * len(wav_files))
|
||||
combiner.build(wav_files, output_wav, "concatenate")
|
||||
except sox.core.SoxError:
|
||||
return
|
||||
|
||||
|
||||
def maybe_convert_to_wav(base_dir):
|
||||
roots = list(os.walk(base_dir))
|
||||
print('Converting and joining source audio files...')
|
||||
print("Converting and joining source audio files...")
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
tp = ThreadPool()
|
||||
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
|
||||
@ -335,53 +389,66 @@ def assign_sub_sets(samples):
|
||||
sample_set.extend(speakers.pop(0))
|
||||
train_set = sum(speakers, [])
|
||||
if len(train_set) == 0:
|
||||
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
|
||||
print(
|
||||
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
|
||||
)
|
||||
random.seed(42) # same source data == same output
|
||||
random.shuffle(samples)
|
||||
for index, sample in enumerate(samples):
|
||||
if index < sample_size:
|
||||
sample.sub_set = 'dev'
|
||||
sample.sub_set = "dev"
|
||||
elif index < 2 * sample_size:
|
||||
sample.sub_set = 'test'
|
||||
sample.sub_set = "test"
|
||||
else:
|
||||
sample.sub_set = 'train'
|
||||
sample.sub_set = "train"
|
||||
else:
|
||||
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
|
||||
for sub_set, sub_set_samples in [
|
||||
("train", train_set),
|
||||
("dev", sample_sets[0]),
|
||||
("test", sample_sets[1]),
|
||||
]:
|
||||
for sample in sub_set_samples:
|
||||
sample.sub_set = sub_set
|
||||
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
|
||||
print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
|
||||
.format(sub_set, len(sub_set_samples), t))
|
||||
print(
|
||||
'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
|
||||
sub_set, len(sub_set_samples), t
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_sample_dirs(language):
|
||||
print('Creating sample directories...')
|
||||
for set_name in ['train', 'dev', 'test']:
|
||||
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
||||
if not path.isdir(dir_path):
|
||||
print("Creating sample directories...")
|
||||
for set_name in ["train", "dev", "test"]:
|
||||
dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
|
||||
if not os.path.isdir(dir_path):
|
||||
os.mkdir(dir_path)
|
||||
|
||||
|
||||
def split_audio_files(samples, language):
|
||||
print('Splitting audio files...')
|
||||
print("Splitting audio files...")
|
||||
sub_sets = Counter()
|
||||
src_wav_files = group(samples, lambda s: s.wav_path).items()
|
||||
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
|
||||
for wav_path, file_samples in bar(src_wav_files):
|
||||
file_samples = sorted(file_samples, key=lambda s: s.start)
|
||||
with wave.open(wav_path, 'r') as src_wav_file:
|
||||
with wave.open(wav_path, "r") as src_wav_file:
|
||||
rate = src_wav_file.getframerate()
|
||||
for sample in file_samples:
|
||||
index = sub_sets[sample.sub_set]
|
||||
sample_wav_path = path.join(CLI_ARGS.base_dir,
|
||||
language + '-' + sample.sub_set,
|
||||
'sample-{0:06d}.wav'.format(index))
|
||||
sample_wav_path = os.path.join(
|
||||
CLI_ARGS.base_dir,
|
||||
language + "-" + sample.sub_set,
|
||||
"sample-{0:06d}.wav".format(index),
|
||||
)
|
||||
sample.wav_path = sample_wav_path
|
||||
sub_sets[sample.sub_set] += 1
|
||||
src_wav_file.setpos(int(sample.start * rate / 1000.0))
|
||||
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
|
||||
with wave.open(sample_wav_path, 'w') as sample_wav_file:
|
||||
data = src_wav_file.readframes(
|
||||
int((sample.end - sample.start) * rate / 1000.0)
|
||||
)
|
||||
with wave.open(sample_wav_path, "w") as sample_wav_file:
|
||||
sample_wav_file.setnchannels(src_wav_file.getnchannels())
|
||||
sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
|
||||
sample_wav_file.setframerate(rate)
|
||||
@ -391,22 +458,26 @@ def split_audio_files(samples, language):
|
||||
def write_csvs(samples, language):
|
||||
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
||||
base_dir = path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
|
||||
base_dir = os.path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
|
||||
with open(csv_path, "w") as csv_file:
|
||||
writer = csv.DictWriter(
|
||||
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
|
||||
)
|
||||
writer.writeheader()
|
||||
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
|
||||
bar = progressbar.ProgressBar(
|
||||
max_value=len(set_samples), widgets=SIMPLE_BAR
|
||||
)
|
||||
for sample in bar(set_samples):
|
||||
row = {
|
||||
'wav_filename': path.relpath(sample.wav_path, base_dir),
|
||||
'wav_filesize': path.getsize(sample.wav_path),
|
||||
'transcript': sample.text
|
||||
"wav_filename": os.path.relpath(sample.wav_path, base_dir),
|
||||
"wav_filesize": os.path.getsize(sample.wav_path),
|
||||
"transcript": sample.text,
|
||||
}
|
||||
if CLI_ARGS.add_meta:
|
||||
row['article'] = sample.article
|
||||
row['speaker'] = sample.speaker
|
||||
row["article"] = sample.article
|
||||
row["speaker"] = sample.speaker
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
@ -414,8 +485,8 @@ def cleanup(archive, language):
|
||||
if not CLI_ARGS.keep_archive:
|
||||
print('Removing archive "{}"...'.format(archive))
|
||||
os.remove(archive)
|
||||
language_dir = path.join(CLI_ARGS.base_dir, language)
|
||||
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
|
||||
language_dir = os.path.join(CLI_ARGS.base_dir, language)
|
||||
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
|
||||
print('Removing intermediate files in "{}"...'.format(language_dir))
|
||||
shutil.rmtree(language_dir)
|
||||
|
||||
@ -433,34 +504,75 @@ def prepare_language(language):
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
|
||||
parser.add_argument('base_dir', help='Directory containing all data')
|
||||
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
|
||||
parser.add_argument('--exclude_numbers', type=bool, default=True,
|
||||
help='If sequences with non-transliterated numbers should be excluded')
|
||||
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
|
||||
parser.add_argument('--ignore_too_long', type=bool, default=False,
|
||||
help='If samples exceeding max_duration should be removed')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
|
||||
parser.add_argument("base_dir", help="Directory containing all data")
|
||||
parser.add_argument(
|
||||
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_numbers",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If sequences with non-transliterated numbers should be excluded",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_duration",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum sample duration in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_too_long",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="If samples exceeding max_duration should be removed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
for language in LANGUAGES:
|
||||
parser.add_argument('--{}_alphabet'.format(language),
|
||||
help='Exclude {} samples with characters not in provided alphabet file'.format(language))
|
||||
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns')
|
||||
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers')
|
||||
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles')
|
||||
parser.add_argument('--keep_archive', type=bool, default=True,
|
||||
help='If downloaded archives should be kept')
|
||||
parser.add_argument('--keep_intermediate', type=bool, default=False,
|
||||
help='If intermediate files should be kept')
|
||||
parser.add_argument(
|
||||
"--{}_alphabet".format(language),
|
||||
help="Exclude {} samples with characters not in provided alphabet file".format(
|
||||
language
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_unknown_speakers",
|
||||
action="store_true",
|
||||
help="Exclude unknown speakers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_unknown_articles",
|
||||
action="store_true",
|
||||
help="Exclude unknown articles",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_archive",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If downloaded archives should be kept",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_intermediate",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="If intermediate files should be kept",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
if CLI_ARGS.language == 'all':
|
||||
if CLI_ARGS.language == "all":
|
||||
for lang in LANGUAGES:
|
||||
prepare_language(lang)
|
||||
elif CLI_ARGS.language in LANGUAGES:
|
||||
prepare_language(CLI_ARGS.language)
|
||||
else:
|
||||
fail('Wrong language id')
|
||||
fail("Wrong language id")
|
||||
|
@ -1,24 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import pandas
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
|
||||
from glob import glob
|
||||
from os import makedirs, path, remove, rmdir
|
||||
|
||||
import pandas
|
||||
from sox import Transformer
|
||||
from util.downloader import maybe_download
|
||||
from tensorflow.python.platform import gfile
|
||||
from util.stm import parse_stm_file
|
||||
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from deepspeech_training.util.stm import parse_stm_file
|
||||
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data
|
||||
@ -41,6 +35,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
|
||||
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
|
||||
|
||||
|
||||
def _maybe_extract(data_dir, extracted_data, archive):
|
||||
# If data_dir/extracted_data does not exist, extract archive in data_dir
|
||||
if not gfile.Exists(path.join(data_dir, extracted_data)):
|
||||
@ -48,6 +43,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
|
||||
tar.extractall(data_dir)
|
||||
tar.close()
|
||||
|
||||
|
||||
def _maybe_convert_wav(data_dir, extracted_data):
|
||||
# Create extracted_data dir
|
||||
extracted_dir = path.join(data_dir, extracted_data)
|
||||
@ -61,6 +57,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
|
||||
# Conditionally convert test sph to wav
|
||||
_maybe_convert_wav_dataset(extracted_dir, "test")
|
||||
|
||||
|
||||
def _maybe_convert_wav_dataset(extracted_dir, data_set):
|
||||
# Create source dir
|
||||
source_dir = path.join(extracted_dir, data_set, "sph")
|
||||
@ -84,6 +81,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
|
||||
# Remove source_dir
|
||||
rmdir(source_dir)
|
||||
|
||||
|
||||
def _maybe_split_sentences(data_dir, extracted_data):
|
||||
# Create extracted_data dir
|
||||
extracted_dir = path.join(data_dir, extracted_data)
|
||||
@ -99,6 +97,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
|
||||
|
||||
return train_files, dev_files, test_files
|
||||
|
||||
|
||||
def _maybe_split_dataset(extracted_dir, data_set):
|
||||
# Create stm dir
|
||||
stm_dir = path.join(extracted_dir, data_set, "stm")
|
||||
@ -116,14 +115,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
|
||||
# Open wav corresponding to stm_file
|
||||
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
|
||||
wav_file = path.join(wav_dir, wav_filename)
|
||||
origAudio = wave.open(wav_file,'r')
|
||||
origAudio = wave.open(wav_file, "r")
|
||||
|
||||
# Loop over stm_segments and split wav_file for each segment
|
||||
for stm_segment in stm_segments:
|
||||
# Create wav segment filename
|
||||
start_time = stm_segment.start_time
|
||||
stop_time = stm_segment.stop_time
|
||||
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
path.splitext(path.basename(stm_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
new_wav_file = path.join(wav_dir, new_wav_filename)
|
||||
|
||||
# If the wav segment filename does not exist create it
|
||||
@ -131,23 +137,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
|
||||
_split_wav(origAudio, start_time, stop_time, new_wav_file)
|
||||
|
||||
new_wav_filesize = path.getsize(new_wav_file)
|
||||
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript))
|
||||
files.append(
|
||||
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
|
||||
)
|
||||
|
||||
# Close origAudio
|
||||
origAudio.close()
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
frameRate = origAudio.getframerate()
|
||||
origAudio.setpos(int(start_time*frameRate))
|
||||
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate))
|
||||
chunkAudio = wave.open(new_wav_file,'w')
|
||||
origAudio.setpos(int(start_time * frameRate))
|
||||
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
|
||||
chunkAudio = wave.open(new_wav_file, "w")
|
||||
chunkAudio.setnchannels(origAudio.getnchannels())
|
||||
chunkAudio.setsampwidth(origAudio.getsampwidth())
|
||||
chunkAudio.setframerate(frameRate)
|
||||
chunkAudio.writeframes(chunkData)
|
||||
chunkAudio.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
NAME : LDC TIMIT Dataset
|
||||
URL : https://catalog.ldc.upenn.edu/ldc93s1
|
||||
HOURS : 5
|
||||
@ -8,29 +8,32 @@
|
||||
AUTHORS : Garofolo, John, et al.
|
||||
TYPE : LDC Membership
|
||||
LICENCE : LDC User Agreement
|
||||
'''
|
||||
"""
|
||||
|
||||
import errno
|
||||
import fnmatch
|
||||
import os
|
||||
from os import path
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import fnmatch
|
||||
from os import path
|
||||
|
||||
import pandas as pd
|
||||
import subprocess
|
||||
|
||||
|
||||
def clean(word):
|
||||
# LC ALL & strip punctuation which are not required
|
||||
new = word.lower().replace('.', '')
|
||||
new = new.replace(',', '')
|
||||
new = new.replace(';', '')
|
||||
new = new.replace('"', '')
|
||||
new = new.replace('!', '')
|
||||
new = new.replace('?', '')
|
||||
new = new.replace(':', '')
|
||||
new = new.replace('-', '')
|
||||
new = word.lower().replace(".", "")
|
||||
new = new.replace(",", "")
|
||||
new = new.replace(";", "")
|
||||
new = new.replace('"', "")
|
||||
new = new.replace("!", "")
|
||||
new = new.replace("?", "")
|
||||
new = new.replace(":", "")
|
||||
new = new.replace("-", "")
|
||||
return new
|
||||
|
||||
|
||||
def _preprocess_data(args):
|
||||
|
||||
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
|
||||
@ -40,16 +43,24 @@ def _preprocess_data(args):
|
||||
|
||||
if ignoreSASentences:
|
||||
print("Using recommended ignore SA sentences")
|
||||
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)")
|
||||
print(
|
||||
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
|
||||
)
|
||||
else:
|
||||
print("Using unrecommended setting to include SA sentences")
|
||||
|
||||
datapath = args
|
||||
target = path.join(datapath, "TIMIT")
|
||||
print("Checking to see if data has already been extracted in given argument: %s", target)
|
||||
print(
|
||||
"Checking to see if data has already been extracted in given argument: %s",
|
||||
target,
|
||||
)
|
||||
|
||||
if not path.isdir(target):
|
||||
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath)
|
||||
print(
|
||||
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
|
||||
datapath,
|
||||
)
|
||||
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
|
||||
if path.isfile(filepath):
|
||||
print("File found, extracting")
|
||||
@ -103,40 +114,58 @@ def _preprocess_data(args):
|
||||
# if ignoreSAsentences we only want those without SA in the name
|
||||
# OR
|
||||
# if not ignoreSAsentences we want all to be added
|
||||
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences):
|
||||
if 'train' in full_wav.lower():
|
||||
if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
|
||||
not ignoreSASentences
|
||||
):
|
||||
if "train" in full_wav.lower():
|
||||
train_list_wavs.append(full_wav)
|
||||
train_list_trans.append(trans)
|
||||
train_list_size.append(wav_filesize)
|
||||
elif 'test' in full_wav.lower():
|
||||
elif "test" in full_wav.lower():
|
||||
test_list_wavs.append(full_wav)
|
||||
test_list_trans.append(trans)
|
||||
test_list_size.append(wav_filesize)
|
||||
else:
|
||||
raise IOError
|
||||
|
||||
a = {'wav_filename': train_list_wavs,
|
||||
'wav_filesize': train_list_size,
|
||||
'transcript': train_list_trans
|
||||
}
|
||||
a = {
|
||||
"wav_filename": train_list_wavs,
|
||||
"wav_filesize": train_list_size,
|
||||
"transcript": train_list_trans,
|
||||
}
|
||||
|
||||
c = {'wav_filename': test_list_wavs,
|
||||
'wav_filesize': test_list_size,
|
||||
'transcript': test_list_trans
|
||||
}
|
||||
c = {
|
||||
"wav_filename": test_list_wavs,
|
||||
"wav_filesize": test_list_size,
|
||||
"transcript": test_list_trans,
|
||||
}
|
||||
|
||||
all = {'wav_filename': train_list_wavs + test_list_wavs,
|
||||
'wav_filesize': train_list_size + test_list_size,
|
||||
'transcript': train_list_trans + test_list_trans
|
||||
}
|
||||
all = {
|
||||
"wav_filename": train_list_wavs + test_list_wavs,
|
||||
"wav_filesize": train_list_size + test_list_size,
|
||||
"transcript": train_list_trans + test_list_trans,
|
||||
}
|
||||
|
||||
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_all = pd.DataFrame(
|
||||
all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
df_train = pd.DataFrame(
|
||||
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
df_test = pd.DataFrame(
|
||||
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
|
||||
df_all.to_csv(
|
||||
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
df_train.to_csv(
|
||||
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
df_test.to_csv(
|
||||
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
|
||||
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
|
||||
if __name__ == "__main__":
|
||||
_preprocess_data(sys.argv[1])
|
||||
|
150
bin/import_ts.py
150
bin/import_ts.py
@ -1,52 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import unidecode
|
||||
import zipfile
|
||||
import sox
|
||||
import subprocess
|
||||
import progressbar
|
||||
|
||||
import zipfile
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
from util.downloader import maybe_download
|
||||
import unidecode
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
ARCHIVE_NAME = '2019-04-11_fr_FR'
|
||||
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME
|
||||
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip'
|
||||
ARCHIVE_NAME = "2019-04-11_fr_FR"
|
||||
ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
|
||||
ARCHIVE_URL = (
|
||||
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
|
||||
archive_path = maybe_download(
|
||||
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
|
||||
)
|
||||
# Conditionally extract archive data
|
||||
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
|
||||
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible)
|
||||
_maybe_convert_sets(
|
||||
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
|
||||
)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -58,16 +59,20 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
orig_filename = sample['path']
|
||||
orig_filename = sample["path"]
|
||||
# Storing wav files next to the wav ones - just with a different suffix
|
||||
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
_maybe_convert_wav(orig_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = sample['text']
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = sample["text"]
|
||||
|
||||
rows = []
|
||||
|
||||
@ -75,40 +80,41 @@ def one_sample(sample):
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
|
||||
target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
|
||||
path_to_original_csv = os.path.join(extracted_dir, "data.csv")
|
||||
with open(path_to_original_csv) as csv_f:
|
||||
data = [
|
||||
d for d in csv.DictReader(csv_f, delimiter=',')
|
||||
if float(d['duration']) <= MAX_SECS
|
||||
d
|
||||
for d in csv.DictReader(csv_f, delimiter=",")
|
||||
if float(d["duration"]) <= MAX_SECS
|
||||
]
|
||||
|
||||
for line in data:
|
||||
line['path'] = os.path.join(extracted_dir, line['path'])
|
||||
line["path"] = os.path.join(extracted_dir, line["path"])
|
||||
|
||||
num_samples = len(data)
|
||||
rows = []
|
||||
@ -125,9 +131,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -136,7 +142,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
test_writer.writeheader()
|
||||
|
||||
for i, item in enumerate(rows):
|
||||
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
|
||||
transcript = validate_label(
|
||||
cleanup_transcript(
|
||||
item[2], english_compatible=english_compatible
|
||||
)
|
||||
)
|
||||
if not transcript:
|
||||
continue
|
||||
wav_filename = os.path.join(target_dir, extracted_data, item[0])
|
||||
@ -147,45 +157,53 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def _maybe_convert_wav(orig_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
transformer.build(orig_filename, wav_filename)
|
||||
except sox.core.SoxError as ex:
|
||||
print('SoX processing error', ex, orig_filename, wav_filename)
|
||||
print("SoX processing error", ex, orig_filename, wav_filename)
|
||||
|
||||
|
||||
PUNCTUATIONS_REG = re.compile(r"[°\-,;!?.()\[\]*…—]")
|
||||
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}')
|
||||
MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
|
||||
|
||||
|
||||
def cleanup_transcript(text, english_compatible=False):
|
||||
text = text.replace('’', "'").replace('\u00A0', ' ')
|
||||
text = PUNCTUATIONS_REG.sub(' ', text)
|
||||
text = MULTIPLE_SPACES_REG.sub(' ', text)
|
||||
text = text.replace("’", "'").replace("\u00A0", " ")
|
||||
text = PUNCTUATIONS_REG.sub(" ", text)
|
||||
text = MULTIPLE_SPACES_REG.sub(" ", text)
|
||||
if english_compatible:
|
||||
text = unidecode.unidecode(text)
|
||||
return text.strip().lower()
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diactrics and other non-ascii chars.')
|
||||
parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--english-compatible",
|
||||
action="store_true",
|
||||
dest="english_compatible",
|
||||
help="Remove diactrics and other non-ascii chars.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -1,45 +1,40 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
|
||||
Use "python3 import_tuda.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import wave
|
||||
import tarfile
|
||||
"""
|
||||
import argparse
|
||||
import progressbar
|
||||
import csv
|
||||
import os
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from collections import Counter
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
TUDA_VERSION = 'v2'
|
||||
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
|
||||
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE)
|
||||
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE)
|
||||
import progressbar
|
||||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
TUDA_VERSION = "v2"
|
||||
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
|
||||
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
|
||||
TUDA_PACKAGE
|
||||
)
|
||||
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)
|
||||
|
||||
CHANNELS = 1
|
||||
SAMPLE_WIDTH = 2
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
|
||||
|
||||
def maybe_extract(archive):
|
||||
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
|
||||
if path.isdir(extracted):
|
||||
extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
|
||||
if os.path.isdir(extracted):
|
||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||
else:
|
||||
print('Extracting "{}"...'.format(archive))
|
||||
@ -52,86 +47,100 @@ def maybe_extract(archive):
|
||||
|
||||
|
||||
def check_and_prepare_sentence(sentence):
|
||||
sentence = sentence.lower().replace('co2', 'c o zwei')
|
||||
sentence = sentence.lower().replace("co2", "c o zwei")
|
||||
chars = []
|
||||
for c in sentence:
|
||||
if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in "äöüß"
|
||||
and (ALPHABET is None or not ALPHABET.has_char(c))
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
for sc in c:
|
||||
if ALPHABET is not None and not ALPHABET.has_char(c):
|
||||
return None
|
||||
chars.append(sc)
|
||||
return validate_label(''.join(chars))
|
||||
return validate_label("".join(chars))
|
||||
|
||||
|
||||
def check_wav_file(wav_path, sentence): # pylint: disable=too-many-return-statements
|
||||
try:
|
||||
with wave.open(wav_path, 'r') as src_wav_file:
|
||||
with wave.open(wav_path, "r") as src_wav_file:
|
||||
rate = src_wav_file.getframerate()
|
||||
channels = src_wav_file.getnchannels()
|
||||
sample_width = src_wav_file.getsampwidth()
|
||||
milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
|
||||
if rate != SAMPLE_RATE:
|
||||
return False, 'wrong sample rate'
|
||||
return False, "wrong sample rate"
|
||||
if channels != CHANNELS:
|
||||
return False, 'wrong number of channels'
|
||||
return False, "wrong number of channels"
|
||||
if sample_width != SAMPLE_WIDTH:
|
||||
return False, 'wrong sample width'
|
||||
return False, "wrong sample width"
|
||||
if milliseconds / len(sentence) < 30:
|
||||
return False, 'too short'
|
||||
return False, "too short"
|
||||
if milliseconds > CLI_ARGS.max_duration > 0:
|
||||
return False, 'too long'
|
||||
return False, "too long"
|
||||
except wave.Error:
|
||||
return False, 'invalid wav file'
|
||||
return False, "invalid wav file"
|
||||
except EOFError:
|
||||
return False, 'premature EOF'
|
||||
return True, 'OK'
|
||||
return False, "premature EOF"
|
||||
return True, "OK"
|
||||
|
||||
|
||||
def write_csvs(extracted):
|
||||
sample_counter = 0
|
||||
reasons = Counter()
|
||||
for sub_set in ['train', 'dev', 'test']:
|
||||
set_path = path.join(extracted, sub_set)
|
||||
for sub_set in ["train", "dev", "test"]:
|
||||
set_path = os.path.join(extracted, sub_set)
|
||||
set_files = os.listdir(set_path)
|
||||
recordings = {}
|
||||
for file in set_files:
|
||||
if file.endswith('.xml'):
|
||||
if file.endswith(".xml"):
|
||||
recordings[file[:-4]] = []
|
||||
for file in set_files:
|
||||
if file.endswith('.wav') and '_' in file:
|
||||
prefix = file.split('_')[0]
|
||||
if file.endswith(".wav") and "_" in file:
|
||||
prefix = file.split("_")[0]
|
||||
if prefix in recordings:
|
||||
recordings[prefix].append(file)
|
||||
recordings = recordings.items()
|
||||
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
|
||||
csv_path = os.path.join(
|
||||
CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
|
||||
)
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
with open(csv_path, "w") as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
|
||||
writer.writeheader()
|
||||
set_dir = path.join(extracted, sub_set)
|
||||
set_dir = os.path.join(extracted, sub_set)
|
||||
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
|
||||
for prefix, wav_names in bar(recordings):
|
||||
xml_path = path.join(set_dir, prefix + '.xml')
|
||||
xml_path = os.path.join(set_dir, prefix + ".xml")
|
||||
meta = ET.parse(xml_path).getroot()
|
||||
sentence = list(meta.iter('cleaned_sentence'))[0].text
|
||||
sentence = list(meta.iter("cleaned_sentence"))[0].text
|
||||
sentence = check_and_prepare_sentence(sentence)
|
||||
if sentence is None:
|
||||
continue
|
||||
for wav_name in wav_names:
|
||||
sample_counter += 1
|
||||
wav_path = path.join(set_path, wav_name)
|
||||
wav_path = os.path.join(set_path, wav_name)
|
||||
keep, reason = check_wav_file(wav_path, sentence)
|
||||
if keep:
|
||||
writer.writerow({
|
||||
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
|
||||
'wav_filesize': path.getsize(wav_path),
|
||||
'transcript': sentence.lower()
|
||||
})
|
||||
writer.writerow(
|
||||
{
|
||||
"wav_filename": os.path.relpath(
|
||||
wav_path, CLI_ARGS.base_dir
|
||||
),
|
||||
"wav_filesize": os.path.getsize(wav_path),
|
||||
"transcript": sentence.lower(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
reasons[reason] += 1
|
||||
if len(reasons.keys()) > 0:
|
||||
print('Excluded samples:')
|
||||
print("Excluded samples:")
|
||||
for reason, n in reasons.most_common():
|
||||
print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))
|
||||
|
||||
@ -150,13 +159,29 @@ def download_and_prepare():
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)')
|
||||
parser.add_argument('base_dir', help='Directory containing all data')
|
||||
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file')
|
||||
parser.add_argument('--keep_archive', type=bool, default=True,
|
||||
help='If downloaded archives should be kept')
|
||||
parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
|
||||
parser.add_argument("base_dir", help="Directory containing all data")
|
||||
parser.add_argument(
|
||||
"--max_duration",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum sample duration in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_archive",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If downloaded archives should be kept",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -1,29 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
|
||||
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
|
||||
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], ".."))
|
||||
|
||||
from util.importers import get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import re
|
||||
from multiprocessing import Pool
|
||||
from zipfile import ZipFile
|
||||
|
||||
import librosa
|
||||
import progressbar
|
||||
|
||||
from os import path
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
from zipfile import ZipFile
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
print_import_report,
|
||||
)
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
@ -37,7 +30,7 @@ ARCHIVE_URL = (
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract common voice data
|
||||
@ -48,8 +41,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print(f"No directory {extracted_path} - extracting archive...")
|
||||
with ZipFile(archive_path, "r") as zipobj:
|
||||
# Extract all the contents of zip file in current directory
|
||||
@ -59,15 +52,17 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data, "wav48")
|
||||
txt_dir = path.join(target_dir, extracted_data, "txt")
|
||||
extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
|
||||
txt_dir = os.path.join(target_dir, extracted_data, "txt")
|
||||
|
||||
directory = os.path.expanduser(extracted_dir)
|
||||
srtd = len(sorted(os.listdir(directory)))
|
||||
all_samples = []
|
||||
|
||||
for target in sorted(os.listdir(directory)):
|
||||
all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
|
||||
all_samples += _maybe_prepare_set(
|
||||
path.join(extracted_dir, os.path.split(target)[-1])
|
||||
)
|
||||
|
||||
num_samples = len(all_samples)
|
||||
print(f"Converting wav files to {SAMPLE_RATE}hz...")
|
||||
@ -81,6 +76,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
|
||||
_write_csv(extracted_dir, txt_dir, target_dir)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
if is_audio_file(sample):
|
||||
y, sr = librosa.load(sample, sr=16000)
|
||||
@ -103,6 +99,7 @@ def _maybe_prepare_set(target_csv):
|
||||
samples = new_samples
|
||||
return samples
|
||||
|
||||
|
||||
def _write_csv(extracted_dir, txt_dir, target_dir):
|
||||
print(f"Writing CSV file")
|
||||
dset_abs_path = extracted_dir
|
||||
@ -197,7 +194,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]
|
||||
|
||||
|
||||
def is_audio_file(filepath):
|
||||
return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
|
||||
return any(
|
||||
os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,24 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import codecs
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import tarfile
|
||||
import pandas
|
||||
import re
|
||||
import unicodedata
|
||||
import tarfile
|
||||
import threading
|
||||
from multiprocessing.pool import ThreadPool
|
||||
|
||||
from six.moves import urllib
|
||||
import unicodedata
|
||||
import urllib
|
||||
from glob import glob
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from os import makedirs, path
|
||||
|
||||
import pandas
|
||||
from bs4 import BeautifulSoup
|
||||
from tensorflow.python.platform import gfile
|
||||
from util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
|
||||
"""The number of jobs to run in parallel"""
|
||||
NUM_PARALLEL = 8
|
||||
@ -26,8 +21,10 @@ NUM_PARALLEL = 8
|
||||
"""Lambda function returns the filename of a path"""
|
||||
filename_of = lambda x: path.split(x)[1]
|
||||
|
||||
|
||||
class AtomicCounter(object):
|
||||
"""A class that atomically increments a counter"""
|
||||
|
||||
def __init__(self, start_count=0):
|
||||
"""Initialize the counter
|
||||
:param start_count: the number to start counting at
|
||||
@ -50,6 +47,7 @@ class AtomicCounter(object):
|
||||
"""Returns the current value of the counter (not atomic)"""
|
||||
return self.__count
|
||||
|
||||
|
||||
def _parallel_downloader(voxforge_url, archive_dir, total, counter):
|
||||
"""Generate a function to download a file based on given parameters
|
||||
This works by currying the above given arguments into a closure
|
||||
@ -61,6 +59,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
|
||||
:param counter: an atomic counter to keep track of # of downloaded files
|
||||
:return: a function that actually downloads a file given these params
|
||||
"""
|
||||
|
||||
def download(d):
|
||||
"""Binds voxforge_url, archive_dir, total, and counter into this scope
|
||||
Downloads the given file
|
||||
@ -68,12 +67,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
|
||||
of the file to download and file is the name of the file to download
|
||||
"""
|
||||
(i, file) = d
|
||||
download_url = voxforge_url + '/' + file
|
||||
download_url = voxforge_url + "/" + file
|
||||
c = counter.increment()
|
||||
print('Downloading file {} ({}/{})...'.format(i+1, c, total))
|
||||
print("Downloading file {} ({}/{})...".format(i + 1, c, total))
|
||||
maybe_download(filename_of(download_url), archive_dir, download_url)
|
||||
|
||||
return download
|
||||
|
||||
|
||||
def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter):
|
||||
"""Generate a function to extract a tar file based on given parameters
|
||||
This works by currying the above given arguments into a closure
|
||||
@ -86,6 +87,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
|
||||
:param counter: an atomic counter to keep track of # of extracted files
|
||||
:return: a function that actually extracts a tar file given these params
|
||||
"""
|
||||
|
||||
def extract(d):
|
||||
"""Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
|
||||
Extracts the given file
|
||||
@ -95,58 +97,74 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
|
||||
(i, archive) = d
|
||||
if i < number_of_test:
|
||||
dataset_dir = path.join(data_dir, "test")
|
||||
elif i<number_of_test+number_of_dev:
|
||||
elif i < number_of_test + number_of_dev:
|
||||
dataset_dir = path.join(data_dir, "dev")
|
||||
else:
|
||||
dataset_dir = path.join(data_dir, "train")
|
||||
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
|
||||
if not gfile.Exists(
|
||||
os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
|
||||
):
|
||||
c = counter.increment()
|
||||
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
|
||||
print("Extracting file {} ({}/{})...".format(i + 1, c, total))
|
||||
tar = tarfile.open(archive)
|
||||
tar.extractall(dataset_dir)
|
||||
tar.close()
|
||||
|
||||
return extract
|
||||
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data to data_dir
|
||||
if not path.isdir(data_dir):
|
||||
makedirs(data_dir)
|
||||
|
||||
archive_dir = data_dir+"/archive"
|
||||
archive_dir = data_dir + "/archive"
|
||||
if not path.isdir(archive_dir):
|
||||
makedirs(archive_dir)
|
||||
|
||||
print("Downloading Voxforge data set into {} if not already present...".format(archive_dir))
|
||||
print(
|
||||
"Downloading Voxforge data set into {} if not already present...".format(
|
||||
archive_dir
|
||||
)
|
||||
)
|
||||
|
||||
voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit'
|
||||
voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
|
||||
html_page = urllib.request.urlopen(voxforge_url)
|
||||
soup = BeautifulSoup(html_page, 'html.parser')
|
||||
soup = BeautifulSoup(html_page, "html.parser")
|
||||
|
||||
# list all links
|
||||
refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']]
|
||||
refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]
|
||||
|
||||
# download files in parallel
|
||||
print('{} files to download'.format(len(refs)))
|
||||
downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter())
|
||||
print("{} files to download".format(len(refs)))
|
||||
downloader = _parallel_downloader(
|
||||
voxforge_url, archive_dir, len(refs), AtomicCounter()
|
||||
)
|
||||
p = ThreadPool(NUM_PARALLEL)
|
||||
p.map(downloader, enumerate(refs))
|
||||
|
||||
# Conditionally extract data to dataset_dir
|
||||
if not path.isdir(path.join(data_dir,"test")):
|
||||
makedirs(path.join(data_dir,"test"))
|
||||
if not path.isdir(path.join(data_dir,"dev")):
|
||||
makedirs(path.join(data_dir,"dev"))
|
||||
if not path.isdir(path.join(data_dir,"train")):
|
||||
makedirs(path.join(data_dir,"train"))
|
||||
if not path.isdir(os.path.join(data_dir, "test")):
|
||||
makedirs(os.path.join(data_dir, "test"))
|
||||
if not path.isdir(os.path.join(data_dir, "dev")):
|
||||
makedirs(os.path.join(data_dir, "dev"))
|
||||
if not path.isdir(os.path.join(data_dir, "train")):
|
||||
makedirs(os.path.join(data_dir, "train"))
|
||||
|
||||
tarfiles = glob(path.join(archive_dir, "*.tgz"))
|
||||
tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
|
||||
number_of_files = len(tarfiles)
|
||||
number_of_test = number_of_files//100
|
||||
number_of_dev = number_of_files//100
|
||||
number_of_test = number_of_files // 100
|
||||
number_of_dev = number_of_files // 100
|
||||
|
||||
# extract tars in parallel
|
||||
print("Extracting Voxforge data set into {} if not already present...".format(data_dir))
|
||||
extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter())
|
||||
print(
|
||||
"Extracting Voxforge data set into {} if not already present...".format(
|
||||
data_dir
|
||||
)
|
||||
)
|
||||
extracter = _parallel_extracter(
|
||||
data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
|
||||
)
|
||||
p.map(extracter, enumerate(tarfiles))
|
||||
|
||||
# Generate data set
|
||||
@ -156,42 +174,50 @@ def _download_and_preprocess_data(data_dir):
|
||||
train_files = _generate_dataset(data_dir, "train")
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False)
|
||||
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False)
|
||||
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False)
|
||||
train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
|
||||
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
|
||||
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
|
||||
|
||||
|
||||
def _generate_dataset(data_dir, data_set):
|
||||
extracted_dir = path.join(data_dir, data_set)
|
||||
files = []
|
||||
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")):
|
||||
if path.isdir(path.join(promts_file[:-11],"wav")):
|
||||
with codecs.open(promts_file, 'r', 'utf-8') as f:
|
||||
for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
|
||||
if path.isdir(os.path.join(promts_file[:-11], "wav")):
|
||||
with codecs.open(promts_file, "r", "utf-8") as f:
|
||||
for line in f:
|
||||
id = line.split(' ')[0].split('/')[-1]
|
||||
sentence = ' '.join(line.split(' ')[1:])
|
||||
sentence = re.sub("[^a-z']"," ",sentence.strip().lower())
|
||||
id = line.split(" ")[0].split("/")[-1]
|
||||
sentence = " ".join(line.split(" ")[1:])
|
||||
sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
|
||||
transcript = ""
|
||||
for token in sentence.split(" "):
|
||||
word = token.strip()
|
||||
if word!="" and word!=" ":
|
||||
if word != "" and word != " ":
|
||||
transcript += word + " "
|
||||
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
|
||||
if gfile.Exists(wav_file):
|
||||
wav_filesize = path.getsize(wav_file)
|
||||
# remove audios that are shorter than 0.5s and longer than 20s.
|
||||
# remove audios that are too short for transcript.
|
||||
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \
|
||||
wav_filesize/len(transcript)>1400:
|
||||
files.append((path.abspath(wav_file), wav_filesize, transcript))
|
||||
if (
|
||||
(wav_filesize / 32000) > 0.5
|
||||
and (wav_filesize / 32000) < 20
|
||||
and transcript != ""
|
||||
and wav_filesize / len(transcript) > 1400
|
||||
):
|
||||
files.append(
|
||||
(os.path.abspath(wav_file), wav_filesize, transcript)
|
||||
)
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
if __name__=="__main__":
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -1,15 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import sys
|
||||
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
|
||||
def main():
|
||||
with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
|
||||
with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
|
||||
graph_def = tfv1.GraphDef()
|
||||
graph_def.ParseFromString(fin.read())
|
||||
|
||||
print('\n'.join(sorted(set(n.op for n in graph_def.node))))
|
||||
print("\n".join(sorted(set(n.op for n in graph_def.node))))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
47
bin/play.py
47
bin/play.py
@ -3,19 +3,13 @@
|
||||
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
|
||||
Use "python3 build_sdb.py -h" for help
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import random
|
||||
import argparse
|
||||
import random
|
||||
import sys
|
||||
|
||||
from util.sample_collections import samples_from_file, LabeledSample
|
||||
from util.audio import AUDIO_TYPE_PCM
|
||||
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
|
||||
from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file
|
||||
|
||||
|
||||
def play_sample(samples, index):
|
||||
@ -24,7 +18,7 @@ def play_sample(samples, index):
|
||||
if CLI_ARGS.random:
|
||||
index = random.randint(0, len(samples))
|
||||
elif index >= len(samples):
|
||||
print('No sample with index {}'.format(CLI_ARGS.start))
|
||||
print("No sample with index {}".format(CLI_ARGS.start))
|
||||
sys.exit(1)
|
||||
sample = samples[index]
|
||||
print('Sample "{}"'.format(sample.sample_id))
|
||||
@ -50,13 +44,28 @@ def play_collection():
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
|
||||
'and DeepSpeech CSV files')
|
||||
parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
|
||||
parser.add_argument('--start', type=int, default=0,
|
||||
help='Sample index to start at (negative numbers are relative to the end of the collection)')
|
||||
parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
|
||||
parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tool for playing samples from Sample Databases (SDB files) "
|
||||
"and DeepSpeech CSV files"
|
||||
)
|
||||
parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Sample index to start at (negative numbers are relative to the end of the collection)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--number",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of samples to play (-1 for endless)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random",
|
||||
action="store_true",
|
||||
help="If samples should be played in random order",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -70,5 +79,5 @@ if __name__ == "__main__":
|
||||
try:
|
||||
play_collection()
|
||||
except KeyboardInterrupt:
|
||||
print(' Stopped')
|
||||
print(" Stopped")
|
||||
sys.exit(0)
|
||||
|
@ -1,17 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from util.text import Alphabet, UTF8Alphabet
|
||||
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
|
||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
|
||||
|
||||
.. code-block::
|
||||
|
||||
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/
|
||||
$ python3 -m venv $HOME/tmp/deepspeech-train-venv/
|
||||
|
||||
Once this command completes successfully, the environment will be ready to be activated.
|
||||
|
||||
@ -46,7 +46,7 @@ Install the required dependencies using ``pip3``\ :
|
||||
.. code-block:: bash
|
||||
|
||||
cd DeepSpeech
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install -e .
|
||||
|
||||
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
|
||||
|
||||
@ -70,7 +70,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
|
||||
.. code-block:: bash
|
||||
|
||||
pip3 uninstall tensorflow
|
||||
pip3 install 'tensorflow-gpu==1.15.0'
|
||||
pip3 install 'tensorflow-gpu==1.15.2'
|
||||
|
||||
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.
|
||||
|
||||
|
158
evaluate.py
Executable file → Normal file
158
evaluate.py
Executable file → Normal file
@ -2,155 +2,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import absl.app
|
||||
import progressbar
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from six.moves import zip
|
||||
|
||||
from util.config import Config, initialize_globals
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from util.evaluate_tools import calculate_and_print_report
|
||||
from util.feeding import create_dataset
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.helpers import check_ctcdecoder_version
|
||||
from util.logging import create_progressbar, log_error, log_progress
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
def sparse_tensor_value_to_texts(value, alphabet):
|
||||
r"""
|
||||
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
||||
representing its values, converting tokens to strings using ``alphabet``.
|
||||
"""
|
||||
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
||||
|
||||
|
||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||
indices = sp_tuple[0]
|
||||
values = sp_tuple[1]
|
||||
results = [[] for _ in range(sp_tuple[2][0])]
|
||||
for i, index in enumerate(indices):
|
||||
results[index[0]].append(values[i])
|
||||
# List of strings
|
||||
return [alphabet.decode(res) for res in results]
|
||||
|
||||
|
||||
def evaluate(test_csvs, create_model):
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
|
||||
test_csvs = FLAGS.test_files.split(',')
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||
tfv1.data.get_output_shapes(test_sets[0]),
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||
|
||||
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x,
|
||||
batch_size=FLAGS.test_batch_size,
|
||||
seq_length=batch_x_len,
|
||||
dropout=no_dropout)
|
||||
|
||||
# Transpose to batch major and apply softmax for decoder
|
||||
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
|
||||
|
||||
loss = tfv1.nn.ctc_loss(labels=batch_y,
|
||||
inputs=logits,
|
||||
sequence_length=batch_x_len)
|
||||
|
||||
tfv1.train.get_or_create_global_step()
|
||||
|
||||
# Get number of accessible CPU cores for this process
|
||||
try:
|
||||
num_processes = cpu_count()
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
losses = []
|
||||
predictions = []
|
||||
ground_truths = []
|
||||
|
||||
bar = create_progressbar(prefix='Test epoch | ',
|
||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||
log_progress('Test epoch...')
|
||||
|
||||
step_count = 0
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# First pass, compute losses and transposed logits for decoding
|
||||
while True:
|
||||
try:
|
||||
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||
num_processes=num_processes, scorer=scorer,
|
||||
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
predictions.extend(d[0][1] for d in decoded)
|
||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||
losses.extend(batch_loss)
|
||||
|
||||
step_count += 1
|
||||
bar.update(step_count)
|
||||
|
||||
bar.finish()
|
||||
|
||||
# Print test summary
|
||||
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||
return test_samples
|
||||
|
||||
samples = []
|
||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||
print('Testing model on {}'.format(csv))
|
||||
samples.extend(run_test(init_op, dataset=csv))
|
||||
return samples
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if not FLAGS.test_files:
|
||||
log_error('You need to specify what files to use for evaluation via '
|
||||
'the --test_files flag.')
|
||||
sys.exit(1)
|
||||
|
||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
try:
|
||||
from deepspeech_training import evaluate as ds_evaluate
|
||||
except ImportError:
|
||||
print('Training package is not installed. See training documentation.')
|
||||
raise
|
||||
|
||||
ds_evaluate.run_script()
|
||||
|
@ -10,13 +10,12 @@ import csv
|
||||
import os
|
||||
import sys
|
||||
|
||||
from functools import partial
|
||||
from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from deepspeech import Model
|
||||
|
||||
from util.evaluate_tools import calculate_and_print_report
|
||||
from util.flags import create_flags
|
||||
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
|
||||
from deepspeech_training.util.flags import create_flags
|
||||
from functools import partial
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from six.moves import zip, range
|
||||
|
||||
r'''
|
||||
This module should be self-contained:
|
||||
|
@ -2,19 +2,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, print_function
|
||||
|
||||
import sys
|
||||
|
||||
import optuna
|
||||
import absl.app
|
||||
from ds_ctcdecoder import Scorer
|
||||
import optuna
|
||||
import sys
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from DeepSpeech import create_model
|
||||
from evaluate import evaluate
|
||||
from util.config import Config, initialize_globals
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error
|
||||
from util.evaluate_tools import wer_cer_batch
|
||||
from deepspeech_training.evaluate import evaluate
|
||||
from deepspeech_training.train import create_model
|
||||
from deepspeech_training.util.config import Config, initialize_globals
|
||||
from deepspeech_training.util.flags import create_flags, FLAGS
|
||||
from deepspeech_training.util.logging import log_error
|
||||
from deepspeech_training.util.evaluate_tools import wer_cer_batch
|
||||
from ds_ctcdecoder import Scorer
|
||||
|
||||
|
||||
def character_based():
|
||||
|
@ -1,24 +0,0 @@
|
||||
# Main training requirements
|
||||
tensorflow == 1.15.2
|
||||
numpy == 1.18.1
|
||||
progressbar2
|
||||
six
|
||||
pyxdg
|
||||
attrdict
|
||||
absl-py
|
||||
semver
|
||||
opuslib == 2.0.0
|
||||
|
||||
# Requirements for building native_client files
|
||||
setuptools
|
||||
|
||||
# Requirements for importers
|
||||
sox
|
||||
bs4
|
||||
pandas
|
||||
requests
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
# Requirements for optimizer
|
||||
optuna
|
60
setup.py
Normal file
60
setup.py
Normal file
@ -0,0 +1,60 @@
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
def main():
|
||||
version_file = Path(__file__).parent / 'VERSION'
|
||||
with open(str(version_file)) as fin:
|
||||
version = fin.read().strip()
|
||||
|
||||
setup(
|
||||
name='deepspeech_training',
|
||||
version=version,
|
||||
description='Training code for mozilla DeepSpeech',
|
||||
url='https://github.com/mozilla/DeepSpeech',
|
||||
author='Mozilla',
|
||||
license='MPL-2.0',
|
||||
# Classifiers help users find your project by categorizing it.
|
||||
#
|
||||
# For a list of valid classifiers, see https://pypi.org/classifiers/
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Developers',
|
||||
'Topic :: Multimedia :: Sound/Audio :: Speech',
|
||||
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
|
||||
'Programming Language :: Python :: 3',
|
||||
],
|
||||
package_dir={'': 'training'},
|
||||
packages=find_packages(where='training'),
|
||||
python_requires='>=3.5, <4',
|
||||
install_requires=[
|
||||
'tensorflow == 1.15.2',
|
||||
'numpy == 1.18.1',
|
||||
'progressbar2',
|
||||
'six',
|
||||
'pyxdg',
|
||||
'attrdict',
|
||||
'absl-py',
|
||||
'semver',
|
||||
'opuslib == 2.0.0',
|
||||
'optuna',
|
||||
'sox',
|
||||
'bs4',
|
||||
'pandas',
|
||||
'requests',
|
||||
'librosa',
|
||||
'soundfile',
|
||||
],
|
||||
# If there are data files included in your packages that need to be
|
||||
# installed, specify them here.
|
||||
package_data={
|
||||
'deepspeech_training': [
|
||||
'VERSION',
|
||||
'GRAPH_VERSION',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
43
stats.py
43
stats.py
@ -1,10 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import functools
|
||||
import pandas
|
||||
|
||||
from deepspeech_training.util.helpers import secs_to_hours
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_csvs(csv_files):
|
||||
# Relative paths are relative to CSV location
|
||||
def absolutify(csv, path):
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
return str(path)
|
||||
return str(csv.parent / path)
|
||||
|
||||
sets = []
|
||||
for csv in csv_files:
|
||||
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
|
||||
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
|
||||
sets.append(file)
|
||||
|
||||
# Concat all sets, drop any extra columns, re-index the final result as 0..N
|
||||
return pandas.concat(sets, join='inner', ignore_index=True)
|
||||
|
||||
from util.helpers import secs_to_hours
|
||||
from util.feeding import read_csvs
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -14,20 +33,16 @@ def main():
|
||||
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
|
||||
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
|
||||
args = parser.parse_args()
|
||||
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
|
||||
in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
|
||||
|
||||
csv_dataframe = read_csvs(in_files)
|
||||
total_bytes = csv_dataframe['wav_filesize'].sum()
|
||||
total_files = len(csv_dataframe.index)
|
||||
total_files = len(csv_dataframe)
|
||||
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
|
||||
|
||||
bytes_without_headers = total_bytes - 44 * total_files
|
||||
|
||||
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8)
|
||||
|
||||
print('total_bytes', total_bytes)
|
||||
print('total_files', total_files)
|
||||
print('bytes_without_headers', bytes_without_headers)
|
||||
print('total_time', secs_to_hours(total_time))
|
||||
print('Total bytes:', total_bytes)
|
||||
print('Total files:', total_files)
|
||||
print('Total time:', secs_to_hours(total_seconds))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
|
||||
set -o pipefail
|
||||
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
which deepspeech
|
||||
|
@ -17,7 +17,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
|
||||
|
@ -16,7 +16,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
|
@ -14,7 +14,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
|
@ -1,10 +1,14 @@
|
||||
import unittest
|
||||
|
||||
from argparse import Namespace
|
||||
from .importers import validate_label_eng, get_validate_label
|
||||
from deepspeech_training.util.importers import validate_label_eng, get_validate_label
|
||||
from pathlib import Path
|
||||
|
||||
def from_here(path):
|
||||
here = Path(__file__)
|
||||
return here.parent / path
|
||||
|
||||
class TestValidateLabelEng(unittest.TestCase):
|
||||
|
||||
def test_numbers(self):
|
||||
label = validate_label_eng("this is a 1 2 3 test")
|
||||
self.assertEqual(label, None)
|
||||
@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
|
||||
self.assertEqual(f('toto1234[{[{[]'), None)
|
||||
|
||||
def test_get_validate_label_missing(self):
|
||||
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
|
||||
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
|
||||
f = get_validate_label(args)
|
||||
self.assertEqual(f, None)
|
||||
|
||||
def test_get_validate_label(self):
|
||||
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
|
||||
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
|
||||
f = get_validate_label(args)
|
||||
l = f('toto')
|
||||
self.assertEqual(l, 'toto')
|
@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
import os
|
||||
|
||||
from .text import Alphabet
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
class TestAlphabetParsing(unittest.TestCase):
|
||||
|
1
training/deepspeech_training/GRAPH_VERSION
Symbolic link
1
training/deepspeech_training/GRAPH_VERSION
Symbolic link
@ -0,0 +1 @@
|
||||
../../GRAPH_VERSION
|
1
training/deepspeech_training/VERSION
Symbolic link
1
training/deepspeech_training/VERSION
Symbolic link
@ -0,0 +1 @@
|
||||
../../VERSION
|
0
training/deepspeech_training/__init__.py
Normal file
0
training/deepspeech_training/__init__.py
Normal file
159
training/deepspeech_training/evaluate.py
Executable file
159
training/deepspeech_training/evaluate.py
Executable file
@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import absl.app
|
||||
import progressbar
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from six.moves import zip
|
||||
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph
|
||||
from .util.evaluate_tools import calculate_and_print_report
|
||||
from .util.feeding import create_dataset
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
from .util.logging import create_progressbar, log_error, log_progress
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
def sparse_tensor_value_to_texts(value, alphabet):
|
||||
r"""
|
||||
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
||||
representing its values, converting tokens to strings using ``alphabet``.
|
||||
"""
|
||||
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
||||
|
||||
|
||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||
indices = sp_tuple[0]
|
||||
values = sp_tuple[1]
|
||||
results = [[] for _ in range(sp_tuple[2][0])]
|
||||
for i, index in enumerate(indices):
|
||||
results[index[0]].append(values[i])
|
||||
# List of strings
|
||||
return [alphabet.decode(res) for res in results]
|
||||
|
||||
|
||||
def evaluate(test_csvs, create_model):
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
|
||||
test_csvs = FLAGS.test_files.split(',')
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||
tfv1.data.get_output_shapes(test_sets[0]),
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||
|
||||
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x,
|
||||
batch_size=FLAGS.test_batch_size,
|
||||
seq_length=batch_x_len,
|
||||
dropout=no_dropout)
|
||||
|
||||
# Transpose to batch major and apply softmax for decoder
|
||||
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
|
||||
|
||||
loss = tfv1.nn.ctc_loss(labels=batch_y,
|
||||
inputs=logits,
|
||||
sequence_length=batch_x_len)
|
||||
|
||||
tfv1.train.get_or_create_global_step()
|
||||
|
||||
# Get number of accessible CPU cores for this process
|
||||
try:
|
||||
num_processes = cpu_count()
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
losses = []
|
||||
predictions = []
|
||||
ground_truths = []
|
||||
|
||||
bar = create_progressbar(prefix='Test epoch | ',
|
||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||
log_progress('Test epoch...')
|
||||
|
||||
step_count = 0
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# First pass, compute losses and transposed logits for decoding
|
||||
while True:
|
||||
try:
|
||||
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||
num_processes=num_processes, scorer=scorer,
|
||||
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
predictions.extend(d[0][1] for d in decoded)
|
||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||
losses.extend(batch_loss)
|
||||
|
||||
step_count += 1
|
||||
bar.update(step_count)
|
||||
|
||||
bar.finish()
|
||||
|
||||
# Print test summary
|
||||
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||
return test_samples
|
||||
|
||||
samples = []
|
||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||
print('Testing model on {}'.format(csv))
|
||||
samples.extend(run_test(init_op, dataset=csv))
|
||||
return samples
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if not FLAGS.test_files:
|
||||
log_error('You need to specify what files to use for evaluation via '
|
||||
'the --test_files flag.')
|
||||
sys.exit(1)
|
||||
|
||||
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
def run_script():
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_script()
|
936
training/deepspeech_training/train.py
Normal file
936
training/deepspeech_training/train.py
Normal file
@ -0,0 +1,936 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||
|
||||
import absl.app
|
||||
import json
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import time
|
||||
|
||||
tfv1.logging.set_verbosity({
|
||||
'0': tfv1.logging.DEBUG,
|
||||
'1': tfv1.logging.INFO,
|
||||
'2': tfv1.logging.WARN,
|
||||
'3': tfv1.logging.ERROR
|
||||
}.get(DESIRED_LOG_LEVEL))
|
||||
|
||||
from datetime import datetime
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from .evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph
|
||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
# Graph Creation
|
||||
# ==============
|
||||
|
||||
def variable_on_cpu(name, shape, initializer):
|
||||
r"""
|
||||
Next we concern ourselves with graph creation.
|
||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||
used to create a variable in CPU memory.
|
||||
"""
|
||||
# Use the /cpu:0 device for scoped operations
|
||||
with tf.device(Config.cpu_device):
|
||||
# Create or get apropos variable
|
||||
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||
return var
|
||||
|
||||
|
||||
def create_overlapping_windows(batch_x):
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
window_width = 2 * Config.n_context + 1
|
||||
num_channels = Config.n_input
|
||||
|
||||
# Create a constant convolution filter using an identity matrix, so that the
|
||||
# convolution returns patches of the input tensor as is, and we can create
|
||||
# overlapping windows over the MFCCs.
|
||||
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
||||
|
||||
# Create overlapping windows
|
||||
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
||||
|
||||
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||
|
||||
return batch_x
|
||||
|
||||
|
||||
def dense(name, x, units, dropout_rate=None, relu=True):
|
||||
with tfv1.variable_scope(name):
|
||||
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||
|
||||
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||
|
||||
if relu:
|
||||
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
||||
|
||||
if dropout_rate is not None:
|
||||
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
output, output_state = fw_cell(inputs=x,
|
||||
dtype=tf.float32,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
||||
|
||||
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||
# the object it creates the variables, and then you just call it several times
|
||||
# to enable variable re-use. Because all of our code is structure in an old
|
||||
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||
# emulating a static function variable.
|
||||
if not rnn_impl_cudnn_rnn.cell:
|
||||
# Forward direction cell:
|
||||
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
||||
num_units=Config.n_cell_dim,
|
||||
input_mode='linear_input',
|
||||
direction='unidirectional',
|
||||
dtype=tf.float32)
|
||||
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||
|
||||
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
||||
sequence_lengths=seq_length)
|
||||
|
||||
return output, output_state
|
||||
|
||||
rnn_impl_cudnn_rnn.cell = None
|
||||
|
||||
|
||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
||||
# Forward direction cell:
|
||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
# Split rank N tensor into list of rank N-1 tensors
|
||||
x = [x[l] for l in range(x.shape[0])]
|
||||
|
||||
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
|
||||
inputs=x,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state,
|
||||
dtype=tf.float32,
|
||||
scope='cell_0')
|
||||
|
||||
output = tf.concat(output, 0)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
|
||||
layers = {}
|
||||
|
||||
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
if not batch_size:
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
|
||||
# Create overlapping feature windows if needed
|
||||
if overlap:
|
||||
batch_x = create_overlapping_windows(batch_x)
|
||||
|
||||
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
||||
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
||||
|
||||
# Permute n_steps and batch_size
|
||||
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
||||
# Reshape to prepare input for first layer
|
||||
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
||||
layers['input_reshaped'] = batch_x
|
||||
|
||||
# The next three blocks will pass `batch_x` through three hidden layers with
|
||||
# clipped RELU activation and dropout.
|
||||
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
|
||||
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
|
||||
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
|
||||
|
||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
||||
|
||||
# Run through parametrized RNN implementation, as we use different RNNs
|
||||
# for training and inference
|
||||
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
||||
|
||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
||||
layers['rnn_output'] = output
|
||||
layers['rnn_output_state'] = output_state
|
||||
|
||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
||||
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
|
||||
|
||||
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
||||
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
|
||||
|
||||
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
||||
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
||||
# Note, that this differs from the input in that it is time-major.
|
||||
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
|
||||
layers['raw_logits'] = layer_6
|
||||
|
||||
# Output shape: [n_steps, batch_size, n_hidden_6]
|
||||
return layer_6, layers
|
||||
|
||||
|
||||
# Accuracy and Loss
|
||||
# =================
|
||||
|
||||
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# the loss function used by our network should be the CTC loss function
|
||||
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
|
||||
# Conveniently, this loss function is implemented in TensorFlow.
|
||||
# Thus, we can simply make use of this implementation to define our loss.
|
||||
|
||||
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
r'''
|
||||
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
|
||||
Next to total and average loss it returns the mean edit distance,
|
||||
the decoded result and the batch's original Y.
|
||||
'''
|
||||
# Obtain the next batch of data
|
||||
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
|
||||
if FLAGS.train_cudnn:
|
||||
rnn_impl = rnn_impl_cudnn_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
# Calculate the logits of the batch
|
||||
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
|
||||
|
||||
# Compute the CTC loss using TensorFlow's `ctc_loss`
|
||||
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
|
||||
|
||||
# Check if any files lead to non finite loss
|
||||
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
|
||||
|
||||
# Calculate the average loss across the batch
|
||||
avg_loss = tf.reduce_mean(input_tensor=total_loss)
|
||||
|
||||
# Finally we return the average loss
|
||||
return avg_loss, non_finite_files
|
||||
|
||||
|
||||
# Adam Optimization
|
||||
# =================
|
||||
|
||||
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# in which 'Nesterov's Accelerated Gradient Descent'
|
||||
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
|
||||
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
|
||||
# because, generally, it requires less fine-tuning.
|
||||
def create_optimizer(learning_rate_var):
|
||||
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||
beta1=FLAGS.beta1,
|
||||
beta2=FLAGS.beta2,
|
||||
epsilon=FLAGS.epsilon)
|
||||
return optimizer
|
||||
|
||||
|
||||
# Towers
|
||||
# ======
|
||||
|
||||
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
|
||||
# not present when using a single GPU, that facilitate the multi-GPU use case.
|
||||
# In particular, one must introduce a means to isolate the inference and gradient
|
||||
# calculations on the various GPU's.
|
||||
# The abstraction we intoduce for this purpose is called a 'tower'.
|
||||
# A tower is specified by two properties:
|
||||
# * **Scope** - A scope, as provided by `tf.name_scope()`,
|
||||
# is a means to isolate the operations within a tower.
|
||||
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
|
||||
# * **Device** - A hardware device, as provided by `tf.device()`,
|
||||
# on which all operations within the tower execute.
|
||||
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
|
||||
|
||||
def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
r'''
|
||||
With this preliminary step out of the way, we can for each GPU introduce a
|
||||
tower for which's batch we calculate and return the optimization gradients
|
||||
and the average loss across towers.
|
||||
'''
|
||||
# To calculate the mean of the losses
|
||||
tower_avg_losses = []
|
||||
|
||||
# Tower gradients to return
|
||||
tower_gradients = []
|
||||
|
||||
# Aggregate any non finite files in the batches
|
||||
tower_non_finite_files = []
|
||||
|
||||
with tfv1.variable_scope(tfv1.get_variable_scope()):
|
||||
# Loop over available_devices
|
||||
for i in range(len(Config.available_devices)):
|
||||
# Execute operations of tower i on device i
|
||||
device = Config.available_devices[i]
|
||||
with tf.device(device):
|
||||
# Create a scope for all operations of tower i
|
||||
with tf.name_scope('tower_%d' % i):
|
||||
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
|
||||
# batch along with the original batch's labels (Y) of this tower
|
||||
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
|
||||
# Allow for variables to be re-used by the next tower
|
||||
tfv1.get_variable_scope().reuse_variables()
|
||||
|
||||
# Retain tower's avg losses
|
||||
tower_avg_losses.append(avg_loss)
|
||||
|
||||
# Compute gradients for model parameters using tower's mini-batch
|
||||
gradients = optimizer.compute_gradients(avg_loss)
|
||||
|
||||
# Retain tower's gradients
|
||||
tower_gradients.append(gradients)
|
||||
|
||||
tower_non_finite_files.append(non_finite_files)
|
||||
|
||||
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
|
||||
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
|
||||
|
||||
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
|
||||
|
||||
# Return gradients and the average loss
|
||||
return tower_gradients, avg_loss_across_towers, all_non_finite_files
|
||||
|
||||
|
||||
def average_gradients(tower_gradients):
|
||||
r'''
|
||||
A routine for computing each variable's average of the gradients obtained from the GPUs.
|
||||
Note also that this code acts as a synchronization point as it requires all
|
||||
GPUs to be finished with their mini-batch before it can run to completion.
|
||||
'''
|
||||
# List of average gradients to return to the caller
|
||||
average_grads = []
|
||||
|
||||
# Run this on cpu_device to conserve GPU memory
|
||||
with tf.device(Config.cpu_device):
|
||||
# Loop over gradient/variable pairs from all towers
|
||||
for grad_and_vars in zip(*tower_gradients):
|
||||
# Introduce grads to store the gradients for the current variable
|
||||
grads = []
|
||||
|
||||
# Loop over the gradients for the current variable
|
||||
for g, _ in grad_and_vars:
|
||||
# Add 0 dimension to the gradients to represent the tower.
|
||||
expanded_g = tf.expand_dims(g, 0)
|
||||
# Append on a 'tower' dimension which we will average over below.
|
||||
grads.append(expanded_g)
|
||||
|
||||
# Average over the 'tower' dimension
|
||||
grad = tf.concat(grads, 0)
|
||||
grad = tf.reduce_mean(input_tensor=grad, axis=0)
|
||||
|
||||
# Create a gradient/variable tuple for the current variable with its average gradient
|
||||
grad_and_var = (grad, grad_and_vars[0][1])
|
||||
|
||||
# Add the current tuple to average_grads
|
||||
average_grads.append(grad_and_var)
|
||||
|
||||
# Return result to caller
|
||||
return average_grads
|
||||
|
||||
|
||||
|
||||
# Logging
|
||||
# =======
|
||||
|
||||
def log_variable(variable, gradient=None):
|
||||
r'''
|
||||
We introduce a function for logging a tensor variable's current state.
|
||||
It logs scalar values for the mean, standard deviation, minimum and maximum.
|
||||
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
|
||||
'''
|
||||
name = variable.name.replace(':', '_')
|
||||
mean = tf.reduce_mean(input_tensor=variable)
|
||||
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
|
||||
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
|
||||
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
|
||||
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
|
||||
tfv1.summary.histogram(name=name, values=variable)
|
||||
if gradient is not None:
|
||||
if isinstance(gradient, tf.IndexedSlices):
|
||||
grad_values = gradient.values
|
||||
else:
|
||||
grad_values = gradient
|
||||
if grad_values is not None:
|
||||
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
|
||||
|
||||
|
||||
def log_grads_and_vars(grads_and_vars):
|
||||
r'''
|
||||
Let's also introduce a helper function for logging collections of gradient/variable tuples.
|
||||
'''
|
||||
for gradient, variable in grads_and_vars:
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def train():
|
||||
do_cache_dataset = True
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||
FLAGS.data_aug_features_additive > 0 or
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0 or
|
||||
FLAGS.augmentation_sparse_warp):
|
||||
do_cache_dataset = False
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
enable_cache=FLAGS.feature_cache and do_cache_dataset,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
output_classes=tfv1.data.get_output_classes(train_set))
|
||||
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
dev_sources = FLAGS.dev_files.split(',')
|
||||
dev_sets = [create_dataset([source],
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
# Dropout
|
||||
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
dropout_feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
no_dropout_feed_dict = {
|
||||
rate: 0. for rate in dropout_rates
|
||||
}
|
||||
|
||||
# Building the graph
|
||||
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
|
||||
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
|
||||
optimizer = create_optimizer(learning_rate_var)
|
||||
|
||||
# Enable mixed precision training
|
||||
if FLAGS.automatic_mixed_precision:
|
||||
log_info('Enabling automatic mixed precision training.')
|
||||
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||
|
||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tfv1.train.get_or_create_global_step()
|
||||
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
|
||||
|
||||
# Summaries
|
||||
step_summaries_op = tfv1.summary.merge_all('step_summaries')
|
||||
step_summary_writers = {
|
||||
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
|
||||
}
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
|
||||
|
||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
# Load checkpoint or initialize variables
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last', 'init']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_set(set_name, epoch, init_op, dataset=None):
|
||||
is_train = set_name == 'train'
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||
|
||||
total_loss = 0.0
|
||||
step_count = 0
|
||||
|
||||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
def __init__(self):
|
||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||
|
||||
def __call__(self, progress, data, **kwargs):
|
||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||
|
||||
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
|
||||
widgets = [' | ', progressbar.widgets.Timer(),
|
||||
' | Steps: ', progressbar.widgets.Counter(),
|
||||
' | ', LossWidget()]
|
||||
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# Batch loop
|
||||
while True:
|
||||
try:
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.InvalidArgumentError as err:
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
log_info("Ignoring sparse warp error: {}".format(err))
|
||||
continue
|
||||
raise
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
|
||||
log_error('The following files caused an infinite (or NaN) '
|
||||
'loss: {}'.format(','.join(problem_files)))
|
||||
|
||||
total_loss += batch_loss
|
||||
step_count += 1
|
||||
|
||||
pbar.update(step_count)
|
||||
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
|
||||
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
pbar.finish()
|
||||
mean_loss = total_loss / step_count if step_count > 0 else 0.0
|
||||
return mean_loss, step_count
|
||||
|
||||
log_info('STARTING Optimization')
|
||||
train_start_time = datetime.utcnow()
|
||||
best_dev_loss = float('inf')
|
||||
dev_losses = []
|
||||
epochs_without_improvement = 0
|
||||
try:
|
||||
for epoch in range(FLAGS.epochs):
|
||||
# Training
|
||||
log_progress('Training epoch %d...' % epoch)
|
||||
train_loss, _ = run_set('train', epoch, train_init_op)
|
||||
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, source))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
total_steps += steps
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
dev_loss = dev_loss / total_steps
|
||||
dev_losses.append(dev_loss)
|
||||
|
||||
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
|
||||
# the improvement has to be greater than FLAGS.es_min_delta
|
||||
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
|
||||
epochs_without_improvement += 1
|
||||
else:
|
||||
epochs_without_improvement = 0
|
||||
|
||||
# Save new best model
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
|
||||
# Early stopping
|
||||
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
|
||||
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
|
||||
epochs_without_improvement))
|
||||
break
|
||||
|
||||
# Reduce learning rate on plateau
|
||||
if (FLAGS.reduce_lr_on_plateau and
|
||||
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
session.run(reduce_learning_rate_op)
|
||||
current_learning_rate = learning_rate_var.eval()
|
||||
log_info('Encountered a plateau, reducing learning rate to {}'.format(
|
||||
current_learning_rate))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||
log_debug('Session closed.')
|
||||
|
||||
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
|
||||
mfccs = tf.identity(mfccs, name='mfccs')
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in DS_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
|
||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
|
||||
if tflite:
|
||||
rnn_impl = rnn_impl_static_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
logits, layers = create_model(batch_x=input_tensor,
|
||||
batch_size=batch_size,
|
||||
seq_length=seq_length if not FLAGS.export_tflite else None,
|
||||
dropout=no_dropout,
|
||||
previous_state=previous_state,
|
||||
overlap=False,
|
||||
rnn_impl=rnn_impl)
|
||||
|
||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits, name='logits')
|
||||
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
||||
if n_steps > 0:
|
||||
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': logits,
|
||||
},
|
||||
layers
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||
|
||||
inputs = {
|
||||
'input': input_tensor,
|
||||
'previous_state_c': previous_state_c,
|
||||
'previous_state_h': previous_state_h,
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
inputs['input_lengths'] = seq_length
|
||||
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
|
||||
def export():
|
||||
r'''
|
||||
Restores the trained variables into a simpler graph that will be exported for serving.
|
||||
'''
|
||||
log_info('Exporting the model...')
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
|
||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
assert graph_version > 0
|
||||
|
||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||
|
||||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
|
||||
output_names = output_names_tensors + output_names_ops
|
||||
|
||||
with tf.Session() as session:
|
||||
# Restore variables from checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
sess=session,
|
||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||
output_node_names=output_names)
|
||||
|
||||
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||
graph_def=frozen_graph,
|
||||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
|
||||
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
|
||||
FLAGS.export_author_id,
|
||||
FLAGS.export_model_name,
|
||||
FLAGS.export_model_version))
|
||||
|
||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||
with open(metadata_fname, 'w') as f:
|
||||
f.write('---\n')
|
||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
|
||||
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
|
||||
f.write('license: {}\n'.format(FLAGS.export_license))
|
||||
f.write('language: {}\n'.format(FLAGS.export_language))
|
||||
f.write('runtime: {}\n'.format(model_runtime))
|
||||
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
|
||||
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
|
||||
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
|
||||
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
|
||||
f.write('---\n')
|
||||
f.write('{}\n'.format(FLAGS.export_description))
|
||||
|
||||
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
||||
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
log_info('Exported packaged model {}'.format(archive))
|
||||
|
||||
|
||||
def do_single_file_inference(input_file_path):
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
||||
|
||||
# Add batch dimension
|
||||
features = tf.expand_dims(features, 0)
|
||||
features_len = tf.expand_dims(features_len, 0)
|
||||
|
||||
# Evaluate
|
||||
features = create_overlapping_windows(features).eval(session=session)
|
||||
features_len = features_len.eval(session=session)
|
||||
|
||||
logits = outputs['outputs'].eval(feed_dict={
|
||||
inputs['input']: features,
|
||||
inputs['input_lengths']: features_len,
|
||||
inputs['previous_state_c']: previous_state_c,
|
||||
inputs['previous_state_h']: previous_state_h,
|
||||
}, session=session)
|
||||
|
||||
logits = np.squeeze(logits)
|
||||
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
||||
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
# Print highest probability result
|
||||
print(decoded[0][1])
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if FLAGS.train_files:
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test_files:
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
export()
|
||||
package_zip()
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
|
||||
def run_script():
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_script()
|
0
training/deepspeech_training/util/__init__.py
Normal file
0
training/deepspeech_training/util/__init__.py
Normal file
@ -5,7 +5,7 @@ import tempfile
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from util.helpers import LimitingPool
|
||||
from .helpers import LimitingPool
|
||||
|
||||
DEFAULT_RATE = 16000
|
||||
DEFAULT_CHANNELS = 1
|
@ -2,8 +2,8 @@ import sys
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from util.flags import FLAGS
|
||||
from util.logging import log_info, log_error, log_warn
|
||||
from .flags import FLAGS
|
||||
from .logging import log_info, log_error, log_warn
|
||||
|
||||
|
||||
def _load_checkpoint(session, checkpoint_path):
|
@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
|
||||
from attrdict import AttrDict
|
||||
from xdg import BaseDirectory as xdg
|
||||
|
||||
from util.flags import FLAGS
|
||||
from util.gpu import get_available_gpus
|
||||
from util.logging import log_error
|
||||
from util.text import Alphabet, UTF8Alphabet
|
||||
from util.helpers import parse_file_size
|
||||
from .flags import FLAGS
|
||||
from .gpu import get_available_gpus
|
||||
from .logging import log_error
|
||||
from .text import Alphabet, UTF8Alphabet
|
||||
from .helpers import parse_file_size
|
||||
|
||||
class ConfigSingleton:
|
||||
_config = None
|
@ -7,8 +7,8 @@ import numpy as np
|
||||
|
||||
from attrdict import AttrDict
|
||||
|
||||
from util.flags import FLAGS
|
||||
from util.text import levenshtein
|
||||
from .flags import FLAGS
|
||||
from .text import levenshtein
|
||||
|
||||
|
||||
def pmap(fun, iterable):
|
@ -8,13 +8,13 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
||||
|
||||
from util.config import Config
|
||||
from util.text import text_to_char_array
|
||||
from util.flags import FLAGS
|
||||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
|
||||
from util.sample_collections import samples_from_files
|
||||
from util.helpers import remember_exception, MEGABYTE
|
||||
from .config import Config
|
||||
from .text import text_to_char_array
|
||||
from .flags import FLAGS
|
||||
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
|
||||
from .sample_collections import samples_from_files
|
||||
from .helpers import remember_exception, MEGABYTE
|
||||
|
||||
|
||||
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
|
@ -4,7 +4,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from util.helpers import secs_to_hours
|
||||
from .helpers import secs_to_hours
|
||||
from collections import Counter
|
||||
|
||||
def get_counter():
|
@ -3,7 +3,7 @@ from __future__ import print_function
|
||||
import progressbar
|
||||
import sys
|
||||
|
||||
from util.flags import FLAGS
|
||||
from .flags import FLAGS
|
||||
|
||||
|
||||
# Logging functions
|
@ -5,8 +5,8 @@ import json
|
||||
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from util.helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||
from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
|
||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
|
||||
|
||||
BIG_ENDIAN = 'big'
|
||||
INT_SIZE = 4
|
@ -1,6 +1,7 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from util.sparse_image_warp import sparse_image_warp
|
||||
|
||||
from .sparse_image_warp import sparse_image_warp
|
||||
|
||||
def augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=30,
|
167
training/deepspeech_training/util/taskcluster.py
Normal file
167
training/deepspeech_training/util/taskcluster.py
Normal file
@ -0,0 +1,167 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function, absolute_import, division
|
||||
|
||||
import argparse
|
||||
import errno
|
||||
import gzip
|
||||
import os
|
||||
import platform
|
||||
import six.moves.urllib as urllib
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from pkg_resources import parse_version
|
||||
|
||||
|
||||
DEFAULT_SCHEMES = {
|
||||
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
|
||||
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
|
||||
}
|
||||
|
||||
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
|
||||
|
||||
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
|
||||
assert arch_string is not None
|
||||
assert artifact_name is not None
|
||||
assert artifact_name
|
||||
assert branch_name is not None
|
||||
assert branch_name
|
||||
|
||||
return TASKCLUSTER_SCHEME % {'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
|
||||
|
||||
def maybe_download_tc(target_dir, tc_url, progress=True):
|
||||
def report_progress(count, block_size, total_size):
|
||||
percent = (count * block_size * 100) // total_size
|
||||
sys.stdout.write("\rDownloading: %d%%" % percent)
|
||||
sys.stdout.flush()
|
||||
|
||||
if percent >= 100:
|
||||
print('\n')
|
||||
|
||||
assert target_dir is not None
|
||||
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
try:
|
||||
os.makedirs(target_dir)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise e
|
||||
assert os.path.isdir(os.path.dirname(target_dir))
|
||||
|
||||
tc_filename = os.path.basename(tc_url)
|
||||
target_file = os.path.join(target_dir, tc_filename)
|
||||
is_gzip = False
|
||||
if not os.path.isfile(target_file):
|
||||
print('Downloading %s ...' % tc_url)
|
||||
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
|
||||
is_gzip = headers.get('Content-Encoding') == 'gzip'
|
||||
else:
|
||||
print('File already exists: %s' % target_file)
|
||||
|
||||
if is_gzip:
|
||||
with open(target_file, "r+b") as frw:
|
||||
decompressed = gzip.decompress(frw.read())
|
||||
frw.seek(0)
|
||||
frw.write(decompressed)
|
||||
frw.truncate()
|
||||
|
||||
return target_file
|
||||
|
||||
def maybe_download_tc_bin(**kwargs):
|
||||
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
|
||||
final_stat = os.stat(final_file)
|
||||
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
|
||||
parser.add_argument('--target', required=False,
|
||||
help='Where to put the native client binary files')
|
||||
parser.add_argument('--arch', required=False,
|
||||
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
|
||||
parser.add_argument('--artifact', required=False,
|
||||
default='native_client.tar.xz',
|
||||
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
|
||||
parser.add_argument('--source', required=False, default=None,
|
||||
help='Name of the TaskCluster scheme to use.')
|
||||
parser.add_argument('--branch', required=False,
|
||||
help='Branch name to use. Defaulting to current content of VERSION file.')
|
||||
parser.add_argument('--decoder', action='store_true',
|
||||
help='Get URL to ds_ctcdecoder Python package.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.target and not args.decoder:
|
||||
print('Pass either --target or --decoder.')
|
||||
sys.exit(1)
|
||||
|
||||
is_arm = 'arm' in platform.machine()
|
||||
is_mac = 'darwin' in sys.platform
|
||||
is_64bit = sys.maxsize > (2**31 - 1)
|
||||
is_ucs2 = sys.maxunicode < 0x10ffff
|
||||
|
||||
if not args.arch:
|
||||
if is_arm:
|
||||
args.arch = 'arm64' if is_64bit else 'arm'
|
||||
elif is_mac:
|
||||
args.arch = 'osx'
|
||||
else:
|
||||
args.arch = 'cpu'
|
||||
|
||||
if not args.branch:
|
||||
version_string = read('../VERSION').strip()
|
||||
ds_version = parse_version(version_string)
|
||||
args.branch = "v{}".format(version_string)
|
||||
else:
|
||||
ds_version = parse_version(args.branch)
|
||||
|
||||
if args.decoder:
|
||||
plat = platform.system().lower()
|
||||
arch = platform.machine()
|
||||
|
||||
if plat == 'linux' and arch == 'x86_64':
|
||||
plat = 'manylinux1'
|
||||
|
||||
if plat == 'darwin':
|
||||
plat = 'macosx_10_10'
|
||||
|
||||
m_or_mu = 'mu' if is_ucs2 else 'm'
|
||||
pyver = ''.join(map(str, sys.version_info[0:2]))
|
||||
|
||||
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
|
||||
ds_version=ds_version,
|
||||
pyver=pyver,
|
||||
m_or_mu=m_or_mu,
|
||||
platform=plat,
|
||||
arch=arch
|
||||
)
|
||||
|
||||
ctc_arch = args.arch + '-ctc'
|
||||
|
||||
print(get_tc_url(ctc_arch, artifact, args.branch))
|
||||
sys.exit(0)
|
||||
|
||||
if args.source is not None:
|
||||
if args.source in DEFAULT_SCHEMES:
|
||||
global TASKCLUSTER_SCHEME
|
||||
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
|
||||
else:
|
||||
print('No such scheme: %s' % args.source)
|
||||
sys.exit(1)
|
||||
|
||||
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
|
||||
|
||||
if args.artifact == "convert_graphdef_memmapped_format":
|
||||
convert_graph_file = os.path.join(args.target, args.artifact)
|
||||
final_stat = os.stat(convert_graph_file)
|
||||
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
if '.tar.' in args.artifact:
|
||||
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -12,13 +12,13 @@ tflogging.set_verbosity(tflogging.ERROR)
|
||||
import logging
|
||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||
|
||||
from multiprocessing import Process, cpu_count
|
||||
from deepspeech_training.util.audio import AudioFile
|
||||
from deepspeech_training.util.config import Config, initialize_globals
|
||||
from deepspeech_training.util.feeding import split_audio_file
|
||||
from deepspeech_training.util.flags import create_flags, FLAGS
|
||||
from deepspeech_training.util.logging import log_error, log_info, log_progress, create_progressbar
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from util.config import Config, initialize_globals
|
||||
from util.audio import AudioFile
|
||||
from util.feeding import split_audio_file
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error, log_info, log_progress, create_progressbar
|
||||
from multiprocessing import Process, cpu_count
|
||||
|
||||
|
||||
def fail(message, code=1):
|
||||
@ -27,8 +27,8 @@ def fail(message, code=1):
|
||||
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from deepspeech_training.util.checkpoints import load_or_init_graph
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||
try:
|
||||
|
@ -1,159 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import csv
|
||||
|
||||
from threading import Thread
|
||||
from time import time
|
||||
from scipy.interpolate import spline
|
||||
|
||||
from six.moves import range
|
||||
# Do this to be able to use without X
|
||||
import matplotlib as mpl
|
||||
mpl.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
class GPUUsage(Thread):
|
||||
def __init__(self, csvfile=None):
|
||||
super(GPUUsage, self).__init__()
|
||||
|
||||
self._cmd = [ 'nvidia-smi', 'dmon', '-d', '1', '-s', 'pucvmet' ]
|
||||
self._names = []
|
||||
self._units = []
|
||||
self._process = None
|
||||
|
||||
self._csv_output = csvfile or os.environ.get('ds_gpu_usage_csv', self.make_basename(prefix='ds-gpu-usage', extension='csv'))
|
||||
|
||||
def get_git_desc(self):
|
||||
return subprocess.check_output(['git', 'describe', '--always', '--abbrev']).strip()
|
||||
|
||||
def make_basename(self, prefix, extension):
|
||||
# Let us assume that this code is executed in the current git clone
|
||||
return '%s.%s.%s.%s' % (prefix, self.get_git_desc(), int(time()), extension)
|
||||
|
||||
def stop(self):
|
||||
if not self._process:
|
||||
print("Trying to stop nvidia-smi but no more process, please fix.")
|
||||
return
|
||||
|
||||
print("Ending nvidia-smi monitoring: PID", self._process.pid)
|
||||
self._process.terminate()
|
||||
print("Ended nvidia-smi monitoring ...")
|
||||
|
||||
def run(self):
|
||||
print("Starting nvidia-smi monitoring")
|
||||
|
||||
# If the system has no CUDA setup, then this will fail.
|
||||
try:
|
||||
self._process = subprocess.Popen(self._cmd, stdout=subprocess.PIPE)
|
||||
except OSError as ex:
|
||||
print("Unable to start monitoring, check your environment:", ex)
|
||||
return
|
||||
|
||||
writer = None
|
||||
with open(self._csv_output, 'w') as f:
|
||||
for line in iter(self._process.stdout.readline, ''):
|
||||
d = self.ingest(line)
|
||||
|
||||
if line.startswith('# '):
|
||||
if len(self._names) == 0:
|
||||
self._names = d
|
||||
writer = csv.DictWriter(f, delimiter=str(','), quotechar=str('"'), fieldnames=d)
|
||||
writer.writeheader()
|
||||
continue
|
||||
if len(self._units) == 0:
|
||||
self._units = d
|
||||
continue
|
||||
else:
|
||||
assert len(self._names) == len(self._units)
|
||||
assert len(d) == len(self._names)
|
||||
assert len(d) > 1
|
||||
writer.writerow(self.merge_line(d))
|
||||
f.flush()
|
||||
|
||||
def ingest(self, line):
|
||||
return map(lambda x: x.replace('-', '0'), filter(lambda x: len(x) > 0, map(lambda x: x.strip(), line.split(' ')[1:])))
|
||||
|
||||
def merge_line(self, line):
|
||||
return dict(zip(self._names, line))
|
||||
|
||||
class GPUUsageChart():
|
||||
def __init__(self, source, basename=None):
|
||||
self._rows = [ 'pwr', 'temp', 'sm', 'mem']
|
||||
self._titles = {
|
||||
'pwr': "Power (W)",
|
||||
'temp': "Temperature (°C)",
|
||||
'sm': "Streaming Multiprocessors (%)",
|
||||
'mem': "Memory (%)"
|
||||
}
|
||||
self._data = { }.fromkeys(self._rows)
|
||||
self._csv = source
|
||||
self._basename = basename or os.environ.get('ds_gpu_usage_charts', 'gpu_usage_%%s_%d.png' % int(time.time()))
|
||||
|
||||
# This should make sure we start from anything clean.
|
||||
plt.close("all")
|
||||
|
||||
try:
|
||||
self.read()
|
||||
for plot in self._rows:
|
||||
self.produce_plot(plot)
|
||||
except IOError as ex:
|
||||
print("Unable to read", ex)
|
||||
|
||||
def append_data(self, row):
|
||||
for bucket, value in row.iteritems():
|
||||
if not bucket in self._rows:
|
||||
continue
|
||||
|
||||
if not self._data[bucket]:
|
||||
self._data[bucket] = {}
|
||||
|
||||
gpu = int(row['gpu'])
|
||||
if not self._data[bucket].has_key(gpu):
|
||||
self._data[bucket][gpu] = [ value ]
|
||||
else:
|
||||
self._data[bucket][gpu] += [ value ]
|
||||
|
||||
def read(self):
|
||||
print("Reading data from", self._csv)
|
||||
with open(self._csv, 'r') as f:
|
||||
for r in csv.DictReader(f):
|
||||
self.append_data(r)
|
||||
|
||||
def produce_plot(self, key, with_spline=True):
|
||||
png = self._basename % (key, )
|
||||
print("Producing plot for", key, "as", png)
|
||||
fig, axis = plt.subplots()
|
||||
data = self._data[key]
|
||||
if data is None:
|
||||
print("Data was empty, aborting")
|
||||
return
|
||||
|
||||
x = list(range(len(data[0])))
|
||||
if with_spline:
|
||||
x = map(lambda x: float(x), x)
|
||||
x_sm = np.array(x)
|
||||
x_smooth = np.linspace(x_sm.min(), x_sm.max(), 300)
|
||||
|
||||
for gpu, y in data.iteritems():
|
||||
if with_spline:
|
||||
y = map(lambda x: float(x), y)
|
||||
y_sm = np.array(y)
|
||||
y_smooth = spline(x, y, x_smooth, order=1)
|
||||
axis.plot(x_smooth, y_smooth, label='GPU %d' % (gpu))
|
||||
else:
|
||||
axis.plot(x, y, label='GPU %d' % (gpu))
|
||||
|
||||
axis.legend(loc="upper right", frameon=False)
|
||||
axis.set_xlabel("Time (s)")
|
||||
axis.set_ylabel("%s" % self._titles[key])
|
||||
fig.set_size_inches(24, 18)
|
||||
plt.title("GPU Usage: %s" % self._titles[key])
|
||||
plt.savefig(png, dpi=100)
|
||||
plt.close(fig)
|
@ -1,168 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function, absolute_import, division
|
||||
|
||||
import argparse
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import errno
|
||||
import stat
|
||||
import gzip
|
||||
|
||||
import six.moves.urllib as urllib
|
||||
|
||||
from pkg_resources import parse_version
|
||||
|
||||
|
||||
DEFAULT_SCHEMES = {
|
||||
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
|
||||
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
|
||||
}
|
||||
|
||||
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
|
||||
|
||||
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
|
||||
assert arch_string is not None
|
||||
assert artifact_name is not None
|
||||
assert artifact_name
|
||||
assert branch_name is not None
|
||||
assert branch_name
|
||||
|
||||
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
|
||||
|
||||
def maybe_download_tc(target_dir, tc_url, progress=True):
|
||||
def report_progress(count, block_size, total_size):
|
||||
percent = (count * block_size * 100) // total_size
|
||||
sys.stdout.write("\rDownloading: %d%%" % percent)
|
||||
sys.stdout.flush()
|
||||
|
||||
if percent >= 100:
|
||||
print('\n')
|
||||
|
||||
assert target_dir is not None
|
||||
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
try:
|
||||
os.makedirs(target_dir)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise e
|
||||
assert os.path.isdir(os.path.dirname(target_dir))
|
||||
|
||||
tc_filename = os.path.basename(tc_url)
|
||||
target_file = os.path.join(target_dir, tc_filename)
|
||||
is_gzip = False
|
||||
if not os.path.isfile(target_file):
|
||||
print('Downloading %s ...' % tc_url)
|
||||
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
|
||||
is_gzip = headers.get('Content-Encoding') == 'gzip'
|
||||
else:
|
||||
print('File already exists: %s' % target_file)
|
||||
|
||||
if is_gzip:
|
||||
with open(target_file, "r+b") as frw:
|
||||
decompressed = gzip.decompress(frw.read())
|
||||
frw.seek(0)
|
||||
frw.write(decompressed)
|
||||
frw.truncate()
|
||||
|
||||
return target_file
|
||||
|
||||
def maybe_download_tc_bin(**kwargs):
|
||||
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
|
||||
final_stat = os.stat(final_file)
|
||||
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
|
||||
parser.add_argument('--target', required=False,
|
||||
help='Where to put the native client binary files')
|
||||
parser.add_argument('--arch', required=False,
|
||||
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
|
||||
parser.add_argument('--artifact', required=False,
|
||||
default='native_client.tar.xz',
|
||||
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
|
||||
parser.add_argument('--source', required=False, default=None,
|
||||
help='Name of the TaskCluster scheme to use.')
|
||||
parser.add_argument('--branch', required=False,
|
||||
help='Branch name to use. Defaulting to current content of VERSION file.')
|
||||
parser.add_argument('--decoder', action='store_true',
|
||||
help='Get URL to ds_ctcdecoder Python package.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.target and not args.decoder:
|
||||
print('Pass either --target or --decoder.')
|
||||
exit(1)
|
||||
|
||||
is_arm = 'arm' in platform.machine()
|
||||
is_mac = 'darwin' in sys.platform
|
||||
is_64bit = sys.maxsize > (2**31 - 1)
|
||||
is_ucs2 = sys.maxunicode < 0x10ffff
|
||||
|
||||
if not args.arch:
|
||||
if is_arm:
|
||||
args.arch = 'arm64' if is_64bit else 'arm'
|
||||
elif is_mac:
|
||||
args.arch = 'osx'
|
||||
else:
|
||||
args.arch = 'cpu'
|
||||
|
||||
if not args.branch:
|
||||
version_string = read('../VERSION').strip()
|
||||
ds_version = parse_version(version_string)
|
||||
args.branch = "v{}".format(version_string)
|
||||
else:
|
||||
ds_version = parse_version(args.branch)
|
||||
|
||||
if args.decoder:
|
||||
plat = platform.system().lower()
|
||||
arch = platform.machine()
|
||||
|
||||
if plat == 'linux' and arch == 'x86_64':
|
||||
plat = 'manylinux1'
|
||||
|
||||
if plat == 'darwin':
|
||||
plat = 'macosx_10_10'
|
||||
|
||||
m_or_mu = 'mu' if is_ucs2 else 'm'
|
||||
pyver = ''.join(map(str, sys.version_info[0:2]))
|
||||
|
||||
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
|
||||
ds_version=ds_version,
|
||||
pyver=pyver,
|
||||
m_or_mu=m_or_mu,
|
||||
platform=plat,
|
||||
arch=arch
|
||||
)
|
||||
|
||||
ctc_arch = args.arch + '-ctc'
|
||||
|
||||
print(get_tc_url(ctc_arch, artifact, args.branch))
|
||||
exit(0)
|
||||
|
||||
if args.source is not None:
|
||||
if args.source in DEFAULT_SCHEMES:
|
||||
global TASKCLUSTER_SCHEME
|
||||
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
|
||||
else:
|
||||
print('No such scheme: %s' % args.source)
|
||||
exit(1)
|
||||
|
||||
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
|
||||
|
||||
if args.artifact == "convert_graphdef_memmapped_format":
|
||||
convert_graph_file = os.path.join(args.target, args.artifact)
|
||||
final_stat = os.stat(convert_graph_file)
|
||||
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
if '.tar.' in args.artifact:
|
||||
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
try:
|
||||
from deepspeech_training.util import taskcluster as dsu_taskcluster
|
||||
except ImportError:
|
||||
print('Training package is not installed. See training documentation.')
|
||||
raise
|
||||
|
||||
dsu_taskcluster.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user