Merge remote-tracking branch 'upstream/master'
commit 00e4dbe3fd

27 .compute
@@ -2,15 +2,15 @@

set -xe

apt-get install -y python3-venv
apt-get install -y python3-venv libopus0

python3 -m venv /tmp/venv
source /tmp/venv/bin/activate

pip install -r <(grep -v tensorflow requirements.txt)
pip install tensorflow-gpu==1.15.0

# Install ds_ctcdecoder package from TaskCluster
pip install $(python3 util/taskcluster.py --decoder)
pip install -U setuptools wheel pip
pip install .
pip uninstall -y tensorflow
pip install tensorflow-gpu==1.14

mkdir -p ../keep/summaries

@@ -18,17 +18,22 @@ data="${SHARED_DIR}/data"
fis="${data}/LDC/fisher"
swb="${data}/LDC/LDC97S62/swb"
lbs="${data}/OpenSLR/LibriSpeech/librivox"
cv="${data}/mozilla/CommonVoice/en_1087h_2019-06-12/clips"
npr="${data}/NPR/WAMU/sets/v0.3"

python -u DeepSpeech.py \
--train_files "${fis}-train.csv","${swb}-train.csv","${lbs}-train-clean-100.csv","${lbs}-train-clean-360.csv","${lbs}-train-other-500.csv" \
--dev_files "${lbs}-dev-clean.csv"\
--test_files "${lbs}-test-clean.csv" \
--train_files "${npr}/best-train.sdb","${npr}/good-train.sdb","${cv}/train.sdb","${fis}-train.sdb","${swb}-train.sdb","${lbs}-train-clean-100.sdb","${lbs}-train-clean-360.sdb","${lbs}-train-other-500.sdb" \
--dev_files "${lbs}-dev-clean.sdb" \
--test_files "${lbs}-test-clean.sdb" \
--train_batch_size 24 \
--dev_batch_size 48 \
--test_batch_size 48 \
--train_cudnn \
--n_hidden 2048 \
--learning_rate 0.0001 \
--dropout_rate 0.2 \
--epoch 13 \
--dropout_rate 0.40 \
--epochs 150 \
--noearly_stop \
--feature_cache "../tmp/feature.cache" \
--checkpoint_dir "../keep" \
--summary_dir "../keep/summaries"
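The invocation above switches the training inputs from CSV lists to Sample Databases (.sdb). As a hedged sketch of inspecting one of those inputs, using the sample_collections helpers that appear later in this diff; the path and the per-sample attribute names are assumptions, not shown here:

from deepspeech_training.util.sample_collections import samples_from_files

# hypothetical path and attribute names
samples = samples_from_files(['librivox-train-clean-100.sdb'], labeled=True)
first = next(iter(samples))
print(first.sample_id, first.transcript)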
1 .gitignore vendored
@@ -19,6 +19,7 @@
/native_client/python/model_wrap.cpp
/native_client/python/utils_wrap.cpp
/native_client/javascript/build
/native_client/javascript/client.js
/native_client/javascript/deepspeech_wrap.cxx
/doc/.build/
/doc/xml-c/
4 .isort.cfg Normal file
@@ -0,0 +1,4 @@
[settings]
line_length=80
multi_line_output=3
default_section=FIRSTPARTY
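The new isort configuration matches the import style used throughout this diff: multi_line_output=3 selects isort's "vertical hanging indent" mode, which together with line_length=80 produces parenthesized import blocks like the ones added to bin/build_sdb.py below, for example:

from deepspeech_training.util.audio import (
    AUDIO_TYPE_OPUS,
    AUDIO_TYPE_WAV,
    change_audio_types,
)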
.travis.yml

@@ -9,7 +9,7 @@ python:

jobs:
  include:
    - stage: cardboard linter
    - name: cardboard linter
      install:
        - pip install --upgrade cardboardlint pylint
      script:
@@ -17,9 +17,10 @@ jobs:
        - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
            cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
          fi
    - stage: python unit tests
    - name: python unit tests
      install:
        - pip install --upgrade -r requirements_tests.txt
        - pip install --upgrade -r requirements_tests.txt;
          pip install --upgrade .
      script:
        - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
            python -m unittest;
937 DeepSpeech.py
@@ -2,934 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import os
import sys

LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL

import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time

tfv1.logging.set_verbosity({
    '0': tfv1.logging.DEBUG,
    '1': tfv1.logging.INFO,
    '2': tfv1.logging.WARN,
    '3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))

from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version, ExceptionBox
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar

check_ctcdecoder_version()


# Graph Creation
# ==============

def variable_on_cpu(name, shape, initializer):
    r"""
    Next we concern ourselves with graph creation.
    However, before we do so we must introduce a utility function ``variable_on_cpu()``
    used to create a variable in CPU memory.
    """
    # Use the /cpu:0 device for scoped operations
    with tf.device(Config.cpu_device):
        # Create or get apropos variable
        var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
    return var


def create_overlapping_windows(batch_x):
    batch_size = tf.shape(input=batch_x)[0]
    window_width = 2 * Config.n_context + 1
    num_channels = Config.n_input

    # Create a constant convolution filter using an identity matrix, so that the
    # convolution returns patches of the input tensor as is, and we can create
    # overlapping windows over the MFCCs.
    eye_filter = tf.constant(np.eye(window_width * num_channels)
                               .reshape(window_width, num_channels, window_width * num_channels), tf.float32)  # pylint: disable=bad-continuation

    # Create overlapping windows
    batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')

    # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
    batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])

    return batch_x
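# --- Editor's illustration, not part of this diff: the identity-filter trick
# used by create_overlapping_windows() above, checked in plain NumPy on toy
# sizes. Convolving with an eye() filter gathers, for each time step, the
# surrounding window of frames, flattened but otherwise unchanged.
def _overlapping_windows_sketch():
    n_input, n_context = 3, 2  # toy sizes; the real values come from Config
    window_width = 2 * n_context + 1
    eye_filter = np.eye(window_width * n_input).reshape(
        window_width, n_input, window_width * n_input)
    x = np.arange(window_width * n_input, dtype=float).reshape(window_width, n_input)
    # The contraction conv1d performs at a single valid position:
    patch = np.tensordot(x, eye_filter, axes=([0, 1], [0, 1]))
    assert np.allclose(patch, x.reshape(-1))  # the window, flattened, as is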
def dense(name, x, units, dropout_rate=None, relu=True):
    with tfv1.variable_scope(name):
        bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
        weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))

    output = tf.nn.bias_add(tf.matmul(x, weights), bias)

    if relu:
        output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)

    if dropout_rate is not None:
        output = tf.nn.dropout(output, rate=dropout_rate)

    return output


def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
                                                    forget_bias=0,
                                                    reuse=reuse,
                                                    name='cudnn_compatible_lstm_cell')

        output, output_state = fw_cell(inputs=x,
                                       dtype=tf.float32,
                                       sequence_length=seq_length,
                                       initial_state=previous_state)

    return output, output_state


def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
    assert previous_state is None  # 'Passing previous state not supported with CuDNN backend'

    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
    # the object it creates the variables, and then you just call it several times
    # to enable variable re-use. Because all of our code is structured in an old
    # school TensorFlow structure where you can just call tf.get_variable again with
    # reuse=True to reuse variables, we can't easily make use of the object oriented
    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
    # emulating a static function variable.
    if not rnn_impl_cudnn_rnn.cell:
        # Forward direction cell:
        fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
                                                 num_units=Config.n_cell_dim,
                                                 input_mode='linear_input',
                                                 direction='unidirectional',
                                                 dtype=tf.float32)
        rnn_impl_cudnn_rnn.cell = fw_cell

    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
                                                   sequence_lengths=seq_length)

    return output, output_state

rnn_impl_cudnn_rnn.cell = None
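# --- Editor's illustration, not part of this diff: the "static function
# variable" pattern described in the comment above. Attributes set on a
# function object persist across calls, so the first call builds the
# expensive object and every later call reuses the same instance.
def _singleton_sketch():
    if _singleton_sketch.instance is None:
        _singleton_sketch.instance = object()  # stand-in for CudnnLSTM(...)
    return _singleton_sketch.instance
_singleton_sketch.instance = None

assert _singleton_sketch() is _singleton_sketch()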
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
        # Forward direction cell:
        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
                                            forget_bias=0,
                                            reuse=reuse,
                                            name='cudnn_compatible_lstm_cell')

        # Split rank N tensor into list of rank N-1 tensors
        x = [x[l] for l in range(x.shape[0])]

        output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
                                                  inputs=x,
                                                  sequence_length=seq_length,
                                                  initial_state=previous_state,
                                                  dtype=tf.float32,
                                                  scope='cell_0')

        output = tf.concat(output, 0)

    return output, output_state


def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
    layers = {}

    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
    if not batch_size:
        batch_size = tf.shape(input=batch_x)[0]

    # Create overlapping feature windows if needed
    if overlap:
        batch_x = create_overlapping_windows(batch_x)

    # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
    # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.

    # Permute n_steps and batch_size
    batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
    # Reshape to prepare input for first layer
    batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context])  # (n_steps*batch_size, n_input + 2*n_input*n_context)
    layers['input_reshaped'] = batch_x

    # The next three blocks will pass `batch_x` through three hidden layers with
    # clipped RELU activation and dropout.
    layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
    layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
    layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])

    # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
    # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
    layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])

    # Run through parametrized RNN implementation, as we use different RNNs
    # for training and inference
    output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)

    # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
    # to a tensor of shape [n_steps*batch_size, n_cell_dim]
    output = tf.reshape(output, [-1, Config.n_cell_dim])
    layers['rnn_output'] = output
    layers['rnn_output_state'] = output_state

    # Now we feed `output` to the fifth hidden layer with clipped RELU activation
    layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])

    # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
    layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)

    # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
    # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
    # Note, that this differs from the input in that it is time-major.
    layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
    layers['raw_logits'] = layer_6

    # Output shape: [n_steps, batch_size, n_hidden_6]
    return layer_6, layers


# Accuracy and Loss
# =================

# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.

def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
    r'''
    This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
    Next to total and average loss it returns the mean edit distance,
    the decoded result and the batch's original Y.
    '''
    # Obtain the next batch of data
    batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()

    if FLAGS.train_cudnn:
        rnn_impl = rnn_impl_cudnn_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    # Calculate the logits of the batch
    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

    # Compute the CTC loss using TensorFlow's `ctc_loss`
    total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

    # Check if any files lead to non finite loss
    non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))

    # Calculate the average loss across the batch
    avg_loss = tf.reduce_mean(input_tensor=total_loss)

    # Finally we return the average loss
    return avg_loss, non_finite_files


# Adam Optimization
# =================

# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
    optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
                                         beta1=FLAGS.beta1,
                                         beta2=FLAGS.beta2,
                                         epsilon=FLAGS.epsilon)
    return optimizer


# Towers
# ======

# In order to properly make use of multiple GPU's, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPU's.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
#   is a means to isolate the operations within a tower.
#   For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
#   on which all operations within the tower execute.
#   For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.

def get_tower_results(iterator, optimizer, dropout_rates):
    r'''
    With this preliminary step out of the way, we can for each GPU introduce a
    tower, for whose batch we calculate and return the optimization gradients
    and the average loss across towers.
    '''
    # To calculate the mean of the losses
    tower_avg_losses = []

    # Tower gradients to return
    tower_gradients = []

    # Aggregate any non finite files in the batches
    tower_non_finite_files = []

    with tfv1.variable_scope(tfv1.get_variable_scope()):
        # Loop over available_devices
        for i in range(len(Config.available_devices)):
            # Execute operations of tower i on device i
            device = Config.available_devices[i]
            with tf.device(device):
                # Create a scope for all operations of tower i
                with tf.name_scope('tower_%d' % i):
                    # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
                    # batch along with the original batch's labels (Y) of this tower
                    avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

                    # Allow for variables to be re-used by the next tower
                    tfv1.get_variable_scope().reuse_variables()

                    # Retain tower's avg losses
                    tower_avg_losses.append(avg_loss)

                    # Compute gradients for model parameters using tower's mini-batch
                    gradients = optimizer.compute_gradients(avg_loss)

                    # Retain tower's gradients
                    tower_gradients.append(gradients)

                    tower_non_finite_files.append(non_finite_files)

    avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
    tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

    all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)

    # Return gradients and the average loss
    return tower_gradients, avg_loss_across_towers, all_non_finite_files


def average_gradients(tower_gradients):
    r'''
    A routine for computing each variable's average of the gradients obtained from the GPUs.
    Note also that this code acts as a synchronization point as it requires all
    GPUs to be finished with their mini-batch before it can run to completion.
    '''
    # List of average gradients to return to the caller
    average_grads = []

    # Run this on cpu_device to conserve GPU memory
    with tf.device(Config.cpu_device):
        # Loop over gradient/variable pairs from all towers
        for grad_and_vars in zip(*tower_gradients):
            # Introduce grads to store the gradients for the current variable
            grads = []

            # Loop over the gradients for the current variable
            for g, _ in grad_and_vars:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)
                # Append on a 'tower' dimension which we will average over below.
                grads.append(expanded_g)

            # Average over the 'tower' dimension
            grad = tf.concat(grads, 0)
            grad = tf.reduce_mean(input_tensor=grad, axis=0)

            # Create a gradient/variable tuple for the current variable with its average gradient
            grad_and_var = (grad, grad_and_vars[0][1])

            # Add the current tuple to average_grads
            average_grads.append(grad_and_var)

    # Return result to caller
    return average_grads
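# --- Editor's illustration, not part of this diff: the stack-and-mean that
# average_gradients() performs, written in NumPy for a single variable with
# gradients from two towers.
def _average_gradients_sketch():
    g0, g1 = np.ones(3), 3 * np.ones(3)
    stacked = np.stack([g0, g1], axis=0)  # add the 'tower' dimension
    assert np.allclose(stacked.mean(axis=0), 2 * np.ones(3))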
# Logging
# =======

def log_variable(variable, gradient=None):
    r'''
    We introduce a function for logging a tensor variable's current state.
    It logs scalar values for the mean, standard deviation, minimum and maximum.
    Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
    '''
    name = variable.name.replace(':', '_')
    mean = tf.reduce_mean(input_tensor=variable)
    tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
    tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
    tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
    tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
    tfv1.summary.histogram(name=name, values=variable)
    if gradient is not None:
        if isinstance(gradient, tf.IndexedSlices):
            grad_values = gradient.values
        else:
            grad_values = gradient
        if grad_values is not None:
            tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)


def log_grads_and_vars(grads_and_vars):
    r'''
    Let's also introduce a helper function for logging collections of gradient/variable tuples.
    '''
    for gradient, variable in grads_and_vars:
        log_variable(variable, gradient=gradient)
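# --- Editor's aside, not part of this diff: the '%s/stddev' scalar above is
# the population standard deviation, sqrt(mean((x - mean)^2)), which is what
# NumPy's std() computes by default (ddof=0).
def _stddev_sketch():
    x = np.array([1.0, 2.0, 3.0, 4.0])
    assert np.isclose(np.sqrt(np.square(x - x.mean()).mean()), x.std())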
def train():
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    else:
                        raise
                except tf.errors.OutOfRangeError:
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count

        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    dev_loss = dev_loss / total_steps
                    dev_losses.append(dev_loss)

                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                    # the improvement has to be greater than FLAGS.es_min_delta
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval()
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))

        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
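# --- Editor's aside, not part of this diff: the validation loop in train()
# averages per-set losses weighted by step count. For example, loss 1.0 over
# 10 steps and loss 2.0 over 30 steps give (1.0*10 + 2.0*30) / 40 = 1.75,
# not the unweighted mean 1.5.
assert (1.0 * 10 + 2.0 * 30) / (10 + 30) == 1.75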
def test():
    samples = evaluate(FLAGS.test_files.split(','), create_model)
    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats
        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    batch_size = batch_size if batch_size > 0 else None

    # Create feature computation graph
    input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
    samples = tf.expand_dims(input_samples, -1)
    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
    mfccs = tf.identity(mfccs, name='mfccs')

    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
    # This shape is read by the native_client in DS_CreateModel to know the
    # value of n_steps, n_context and n_input. Make sure you update the code
    # there if this shape is changed.
    input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')

    if batch_size <= 0:
        # no state management since n_step is expected to be dynamic too (see below)
        previous_state = None
    else:
        previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
        previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

        previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)

    # One rate per layer
    no_dropout = [None] * 6

    if tflite:
        rnn_impl = rnn_impl_static_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    logits, layers = create_model(batch_x=input_tensor,
                                  batch_size=batch_size,
                                  seq_length=seq_length if not FLAGS.export_tflite else None,
                                  dropout=no_dropout,
                                  previous_state=previous_state,
                                  overlap=False,
                                  rnn_impl=rnn_impl)

    # TF Lite runtime will check that input dimensions are 1, 2 or 4
    # by default we get 3, the middle one being batch_size which is forced to
    # one on inference graph, so remove that dimension
    if tflite:
        logits = tf.squeeze(logits, [1])

    # Apply softmax for CTC decoder
    logits = tf.nn.softmax(logits, name='logits')

    if batch_size <= 0:
        if tflite:
            raise NotImplementedError('dynamic batch_size does not support tflite or streaming')
        if n_steps > 0:
            raise NotImplementedError('dynamic batch_size expects n_steps to be dynamic too')
        return (
            {
                'input': input_tensor,
                'input_lengths': seq_length,
            },
            {
                'outputs': logits,
            },
            layers
        )

    new_state_c, new_state_h = layers['rnn_output_state']
    new_state_c = tf.identity(new_state_c, name='new_state_c')
    new_state_h = tf.identity(new_state_h, name='new_state_h')

    inputs = {
        'input': input_tensor,
        'previous_state_c': previous_state_c,
        'previous_state_h': previous_state_h,
        'input_samples': input_samples,
    }

    if not FLAGS.export_tflite:
        inputs['input_lengths'] = seq_length

    outputs = {
        'outputs': logits,
        'new_state_c': new_state_c,
        'new_state_h': new_state_h,
        'mfccs': mfccs,
    }

    return inputs, outputs, layers
def file_relative_read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()


def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    from tensorflow.python.framework.ops import Tensor, Operation

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
    outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
    output_names = output_names_tensors + output_names_ops

    with tf.Session() as session:
        # Restore variables from checkpoint
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)

        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names)

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph,
            dest_nodes=output_names)

        if not FLAGS.export_tflite:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)

        log_info('Models exported at %s' % (FLAGS.export_dir))

    metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
        FLAGS.export_author_id,
        FLAGS.export_model_name,
        FLAGS.export_model_version))

    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
    with open(metadata_fname, 'w') as f:
        f.write('---\n')
        f.write('author: {}\n'.format(FLAGS.export_author_id))
        f.write('model_name: {}\n'.format(FLAGS.export_model_name))
        f.write('model_version: {}\n'.format(FLAGS.export_model_version))
        f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
        f.write('license: {}\n'.format(FLAGS.export_license))
        f.write('language: {}\n'.format(FLAGS.export_language))
        f.write('runtime: {}\n'.format(model_runtime))
        f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
        f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
        f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
        f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
        f.write('---\n')
        f.write('{}\n'.format(FLAGS.export_description))

    log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))


def package_zip():
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '')  # Force ending '/'
    zip_filename = os.path.dirname(export_dir)

    shutil.copy(FLAGS.scorer_path, export_dir)

    archive = shutil.make_archive(zip_filename, 'zip', export_dir)
    log_info('Exported packaged model {}'.format(archive))


def do_single_file_inference(input_file_path):
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        features, features_len = audiofile_to_features(input_file_path)
        previous_state_c = np.zeros([1, Config.n_cell_dim])
        previous_state_h = np.zeros([1, Config.n_cell_dim])

        # Add batch dimension
        features = tf.expand_dims(features, 0)
        features_len = tf.expand_dims(features_len, 0)

        # Evaluate
        features = create_overlapping_windows(features).eval(session=session)
        features_len = features_len.eval(session=session)

        logits = outputs['outputs'].eval(feed_dict={
            inputs['input']: features,
            inputs['input_lengths']: features_len,
            inputs['previous_state_c']: previous_state_c,
            inputs['previous_state_h']: previous_state_h,
        }, session=session)

        logits = np.squeeze(logits)

        if FLAGS.scorer_path:
            scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                            FLAGS.scorer_path, Config.alphabet)
        else:
            scorer = None
        decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
                                          scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
                                          cutoff_top_n=FLAGS.cutoff_top_n)
        # Print highest probability result
        print(decoded[0][1])


def main(_):
    initialize_globals()

    if FLAGS.train_files:
        tfv1.reset_default_graph()
        tfv1.set_random_seed(FLAGS.random_seed)
        train()

    if FLAGS.test_files:
        tfv1.reset_default_graph()
        test()

    if FLAGS.export_dir and not FLAGS.export_zip:
        tfv1.reset_default_graph()
        export()

    if FLAGS.export_zip:
        tfv1.reset_default_graph()
        FLAGS.export_tflite = True

        if os.listdir(FLAGS.export_dir):
            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
            sys.exit(1)

        export()
        package_zip()

    if FLAGS.one_shot_infer:
        tfv1.reset_default_graph()
        do_single_file_inference(FLAGS.one_shot_infer)


if __name__ == '__main__':
    create_flags()
    absl.app.run(main)
    try:
        from deepspeech_training import train as ds_train
    except ImportError:
        print('Training package is not installed. See training documentation.')
        raise

    ds_train.run_script()
11 Dockerfile
@@ -24,7 +24,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
    libsox-fmt-mp3 \
    htop \
    nano \
    swig \
    cmake \
    libboost-all-dev \
    zlib1g-dev \
@@ -66,7 +65,7 @@ RUN wget https://bootstrap.pypa.io/get-pip.py && \

# >> START Configure Tensorflow Build

# Clone TensoFlow from Mozilla repo
# Clone TensorFlow from Mozilla repo
RUN git clone https://github.com/mozilla/tensorflow/
WORKDIR /tensorflow
RUN git checkout r1.15
@@ -150,7 +149,7 @@ COPY . /DeepSpeech/

WORKDIR /DeepSpeech

RUN pip3 --no-cache-dir install -r requirements.txt
RUN pip3 --no-cache-dir install .

# Link DeepSpeech native_client libs to tf folder
RUN ln -s /DeepSpeech/native_client /tensorflow
@@ -201,10 +200,10 @@ WORKDIR /DeepSpeech/native_client
RUN make deepspeech
WORKDIR /DeepSpeech/native_client/python
RUN make bindings
RUN pip3 install dist/deepspeech*
RUN pip3 install --upgrade dist/deepspeech*
WORKDIR /DeepSpeech/native_client/ctcdecode
RUN make
RUN pip3 install dist/*.whl
RUN make bindings
RUN pip3 install --upgrade dist/*.whl


# << END Build and bind
bin/build_sdb.py

@@ -1,53 +1,69 @@
#!/usr/bin/env python
'''
"""
Tool for building Sample Databases (SDB files) from DeepSpeech CSV files and other SDB files
Use "python3 build_sdb.py -h" for help
'''
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

"""
import argparse

import progressbar

from util.downloader import SIMPLE_BAR
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
from util.sample_collections import samples_from_files, DirectSDBWriter
from deepspeech_training.util.audio import (
    AUDIO_TYPE_OPUS,
    AUDIO_TYPE_WAV,
    change_audio_types,
)
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.sample_collections import (
    DirectSDBWriter,
    samples_from_files,
)

AUDIO_TYPE_LOOKUP = {
    'wav': AUDIO_TYPE_WAV,
    'opus': AUDIO_TYPE_OPUS
}
AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}


def build_sdb():
    audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
    with DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled) as sdb_writer:
    with DirectSDBWriter(
        CLI_ARGS.target, audio_type=audio_type, labeled=not CLI_ARGS.unlabeled
    ) as sdb_writer:
        samples = samples_from_files(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
        bar = progressbar.ProgressBar(max_value=len(samples), widgets=SIMPLE_BAR)
        for sample in bar(change_audio_types(samples, audio_type=audio_type, processes=CLI_ARGS.workers)):
        for sample in bar(
            change_audio_types(
                samples, audio_type=audio_type, processes=CLI_ARGS.workers
            )
        ):
            sdb_writer.add(sample)


def handle_args():
    parser = argparse.ArgumentParser(description='Tool for building Sample Databases (SDB files) '
                                                 'from DeepSpeech CSV files and other SDB files')
    parser.add_argument('sources', nargs='+',
                        help='Source CSV and/or SDB files - '
                             'Note: For getting a correctly ordered target SDB, source SDBs have to have their samples '
                             'already ordered from shortest to longest.')
    parser.add_argument('target', help='SDB file to create')
    parser.add_argument('--audio-type', default='opus', choices=AUDIO_TYPE_LOOKUP.keys(),
                        help='Audio representation inside target SDB')
    parser.add_argument('--workers', type=int, default=None,
                        help='Number of encoding SDB workers')
    parser.add_argument('--unlabeled', action='store_true',
                        help='If to build an SDB with unlabeled (audio only) samples - '
                             'typically used for building noise augmentation corpora')
    parser = argparse.ArgumentParser(
        description="Tool for building Sample Databases (SDB files) "
        "from DeepSpeech CSV files and other SDB files"
    )
    parser.add_argument(
        "sources",
        nargs="+",
        help="Source CSV and/or SDB files - "
        "Note: For getting a correctly ordered target SDB, source SDBs have to have their samples "
        "already ordered from shortest to longest.",
    )
    parser.add_argument("target", help="SDB file to create")
    parser.add_argument(
        "--audio-type",
        default="opus",
        choices=AUDIO_TYPE_LOOKUP.keys(),
        help="Audio representation inside target SDB",
    )
    parser.add_argument(
        "--workers", type=int, default=None, help="Number of encoding SDB workers"
    )
    parser.add_argument(
        "--unlabeled",
        action="store_true",
        help="If to build an SDB with unlabeled (audio only) samples - "
        "typically used for building noise augmentation corpora",
    )
    return parser.parse_args()
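A minimal sketch of driving the same pipeline from Python instead of the command line, using only the imports added above; the source and target paths are hypothetical:

from deepspeech_training.util.audio import AUDIO_TYPE_OPUS, change_audio_types
from deepspeech_training.util.sample_collections import (
    DirectSDBWriter,
    samples_from_files,
)

# hypothetical paths; this mirrors the body of build_sdb() above
with DirectSDBWriter('target.sdb', audio_type=AUDIO_TYPE_OPUS, labeled=True) as writer:
    samples = samples_from_files(['source.csv'], labeled=True)
    for sample in change_audio_types(samples, audio_type=AUDIO_TYPE_OPUS, processes=4):
        writer.add(sample)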
@@ -1,11 +0,0 @@
#!/usr/bin/env python

import sys

import os
sys.path.append(os.path.abspath('.'))

from util.gpu_usage import GPUUsage

gu = GPUUsage()
gu.start()
@@ -1,10 +0,0 @@
#!/usr/bin/env python

import sys

import os
sys.path.append(os.path.abspath('.'))

from util.gpu_usage import GPUUsageChart

GPUUsageChart(sys.argv[1], sys.argv[2])
@@ -1,20 +1,21 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow.compat.v1 as tfv1
import sys

import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format


def main():
    # Load and export as string
    with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
    with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
        graph_def = tfv1.GraphDef()
        graph_def.ParseFromString(fin.read())

    with tfv1.gfile.FastGFile(sys.argv[1] + 'txt', 'w') as fout:
    with tfv1.gfile.FastGFile(sys.argv[1] + "txt", "w") as fout:
        fout.write(text_format.MessageToString(graph_def))


if __name__ == '__main__':

if __name__ == "__main__":
    main()
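For context, a hedged sketch of the round trip this script performs, shown on an empty GraphDef with the same imports; an illustration only, not part of the diff:

import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format

graph_def = tfv1.GraphDef()                         # toy, empty graph
as_text = text_format.MessageToString(graph_def)    # what gets written out as text
parsed = text_format.Parse(as_text, tfv1.GraphDef())
assert parsed == graph_def                          # lossless round trip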
@@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser
import glob
import pandas
import os
import tarfile

import pandas

COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print('Extracting {} into {}...'.format(archive_path, target_dir))
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)

@@ -25,9 +19,9 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, 'aidatatang_200zh')
    main_folder = os.path.join(target_dir, "aidatatang_200zh")

    for targz in glob.glob(os.path.join(main_folder, 'corpus', '*', '*.tar.gz')):
    for targz in glob.glob(os.path.join(main_folder, "corpus", "*", "*.tar.gz")):
        extract(targz, os.path.dirname(targz))

    # Folder structure is now:
@@ -46,9 +40,11 @@ def preprocess_data(tgz_file, target_dir):

    # Since the transcripts themselves can contain spaces, we split on space but
    # only once, then build a mapping from file name to transcript
    transcripts_path = os.path.join(main_folder, 'transcript', 'aidatatang_200_zh_transcript.txt')
    transcripts_path = os.path.join(
        main_folder, "transcript", "aidatatang_200_zh_transcript.txt"
    )
    with open(transcripts_path) as fin:
        transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
        transcripts = dict((line.split(" ", maxsplit=1) for line in fin))

    def load_set(glob_path):
        set_files = []
@@ -57,33 +53,39 @@ def preprocess_data(tgz_file, target_dir):
                wav_filename = wav
                wav_filesize = os.path.getsize(wav)
                transcript_key = os.path.splitext(os.path.basename(wav))[0]
                transcript = transcripts[transcript_key].strip('\n')
                transcript = transcripts[transcript_key].strip("\n")
                set_files.append((wav_filename, wav_filesize, transcript))
            except KeyError:
                print('Warning: Missing transcript for WAV file {}.'.format(wav))
                print("Warning: Missing transcript for WAV file {}.".format(wav))
        return set_files

    for subset in ('train', 'dev', 'test'):
        print('Loading {} set samples...'.format(subset))
        subset_files = load_set(os.path.join(main_folder, 'corpus', subset, '*', '*.wav'))
    for subset in ("train", "dev", "test"):
        print("Loading {} set samples...".format(subset))
        subset_files = load_set(
            os.path.join(main_folder, "corpus", subset, "*", "*.wav")
        )
        df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)

        # Trim train set to under 10s by removing the last couple hundred samples
        if subset == 'train':
            durations = (df['wav_filesize'] - 44) / 16000 / 2
        if subset == "train":
            durations = (df["wav_filesize"] - 44) / 16000 / 2
            df = df[durations <= 10.0]
            print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
            print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

        dest_csv = os.path.join(target_dir, 'aidatatang_{}.csv'.format(subset))
        print('Saving {} set into {}...'.format(subset, dest_csv))
        dest_csv = os.path.join(target_dir, "aidatatang_{}.csv".format(subset))
        print("Saving {} set into {}...".format(subset, dest_csv))
        df.to_csv(dest_csv, index=False)


def main():
    # https://www.openslr.org/62/
    parser = get_importers_parser(description='Import aidatatang_200zh corpus')
    parser.add_argument('tgz_file', help='Path to aidatatang_200zh.tgz')
    parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
    parser = get_importers_parser(description="Import aidatatang_200zh corpus")
    parser.add_argument("tgz_file", help="Path to aidatatang_200zh.tgz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()

    if not params.target_dir:
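The trimming step above derives clip duration directly from file size: a canonical 16 kHz, 16-bit mono WAV carries a 44-byte header and two bytes per sample, hence (size - 44) / 16000 / 2. A quick check of the arithmetic:

# a 10-second, 16 kHz, 16-bit mono WAV: 44-byte header + 10 * 16000 * 2 bytes
wav_filesize = 44 + 10 * 16000 * 2
assert (wav_filesize - 44) / 16000 / 2 == 10.0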
@ -1,23 +1,17 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser
import glob
import os
import tarfile

import pandas

from deepspeech_training.util.importers import get_importers_parser

COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print('Extracting {} into {}...'.format(archive_path, target_dir))
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)

@ -25,10 +19,10 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, 'data_aishell')
    main_folder = os.path.join(target_dir, "data_aishell")

    wav_archives_folder = os.path.join(main_folder, 'wav')
    for targz in glob.glob(os.path.join(wav_archives_folder, '*.tar.gz')):
    wav_archives_folder = os.path.join(main_folder, "wav")
    for targz in glob.glob(os.path.join(wav_archives_folder, "*.tar.gz")):
        extract(targz, main_folder)

    # Folder structure is now:
@ -45,9 +39,11 @@ def preprocess_data(tgz_file, target_dir):

    # Since the transcripts themselves can contain spaces, we split on space but
    # only once, then build a mapping from file name to transcript
    transcripts_path = os.path.join(main_folder, 'transcript', 'aishell_transcript_v0.8.txt')
    transcripts_path = os.path.join(
        main_folder, "transcript", "aishell_transcript_v0.8.txt"
    )
    with open(transcripts_path) as fin:
        transcripts = dict((line.split(' ', maxsplit=1) for line in fin))
        transcripts = dict((line.split(" ", maxsplit=1) for line in fin))

    def load_set(glob_path):
        set_files = []
@ -56,33 +52,37 @@ def preprocess_data(tgz_file, target_dir):
                wav_filename = wav
                wav_filesize = os.path.getsize(wav)
                transcript_key = os.path.splitext(os.path.basename(wav))[0]
                transcript = transcripts[transcript_key].strip('\n')
                transcript = transcripts[transcript_key].strip("\n")
                set_files.append((wav_filename, wav_filesize, transcript))
            except KeyError:
                print('Warning: Missing transcript for WAV file {}.'.format(wav))
                print("Warning: Missing transcript for WAV file {}.".format(wav))
        return set_files

    for subset in ('train', 'dev', 'test'):
        print('Loading {} set samples...'.format(subset))
        subset_files = load_set(os.path.join(main_folder, subset, 'S*', '*.wav'))
    for subset in ("train", "dev", "test"):
        print("Loading {} set samples...".format(subset))
        subset_files = load_set(os.path.join(main_folder, subset, "S*", "*.wav"))
        df = pandas.DataFrame(data=subset_files, columns=COLUMNNAMES)

        # Trim train set to under 10s by removing the last couple hundred samples
        if subset == 'train':
            durations = (df['wav_filesize'] - 44) / 16000 / 2
        if subset == "train":
            durations = (df["wav_filesize"] - 44) / 16000 / 2
            df = df[durations <= 10.0]
            print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
            print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

        dest_csv = os.path.join(target_dir, 'aishell_{}.csv'.format(subset))
        print('Saving {} set into {}...'.format(subset, dest_csv))
        dest_csv = os.path.join(target_dir, "aishell_{}.csv".format(subset))
        print("Saving {} set into {}...".format(subset, dest_csv))
        df.to_csv(dest_csv, index=False)


def main():
    # http://www.openslr.org/33/
    parser = get_importers_parser(description='Import AISHELL corpus')
    parser.add_argument('aishell_tgz_file', help='Path to data_aishell.tgz')
    parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
    parser = get_importers_parser(description="Import AISHELL corpus")
    parser.add_argument("aishell_tgz_file", help="Path to data_aishell.tgz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()

    if not params.target_dir:
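
The transcript parsing above relies on `maxsplit=1`: the utterance ID runs up to the first space, and everything after it, spaces included, is the transcript. A quick illustration with hypothetical lines in the aishell_transcript_v0.8.txt layout:

# Hypothetical lines in the "<utterance_id> <transcript>" layout.
lines = [
    "BAC009S0002W0122 广州 市 房地产 中介 协会 分析\n",
    "BAC009S0002W0123 广州 市 房地产 中介 协会 还 表示\n",
]
# Split on the first space only, so spaces inside the transcript survive.
transcripts = dict(line.split(" ", maxsplit=1) for line in lines)
assert transcripts["BAC009S0002W0122"].strip("\n") == "广州 市 房地产 中介 协会 分析"
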
109 bin/import_cv.py
@ -1,34 +1,35 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import csv
import sox
import tarfile
import os
import subprocess
import progressbar

import tarfile
from glob import glob
from os import path
from multiprocessing import Pool
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
from util.downloader import maybe_download, SIMPLE_BAR

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
import progressbar
import sox

from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    print_import_report,
)
from deepspeech_training.util.importers import validate_label_eng as validate_label

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10
ARCHIVE_DIR_NAME = 'cv_corpus_v1'
ARCHIVE_NAME = ARCHIVE_DIR_NAME + '.tar.gz'
ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "cv_corpus_v1"
ARCHIVE_NAME = ARCHIVE_DIR_NAME + ".tar.gz"
ARCHIVE_URL = (
    "https://s3.us-east-2.amazonaws.com/common-voice-data-download/" + ARCHIVE_NAME
)


def _download_and_preprocess_data(target_dir):
    # Making path absolute
    target_dir = path.abspath(target_dir)
    target_dir = os.path.abspath(target_dir)
    # Conditionally download data
    archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
    # Conditionally extract common voice data
@ -36,56 +37,70 @@ def _download_and_preprocess_data(target_dir):
    # Conditionally convert common voice CSV files and mp3 data to DeepSpeech CSVs and wav
    _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)


def _maybe_extract(target_dir, extracted_data, archive_path):
    # If target_dir/extracted_data does not exist, extract archive in target_dir
    extracted_path = path.join(target_dir, extracted_data)
    if not path.exists(extracted_path):
    extracted_path = os.path.join(target_dir, extracted_data)
    if not os.path.exists(extracted_path):
        print('No directory "%s" - extracting archive...' % extracted_path)
        with tarfile.open(archive_path) as tar:
            tar.extractall(target_dir)
    else:
        print('Found directory "%s" - not extracting it from archive.' % extracted_path)


def _maybe_convert_sets(target_dir, extracted_data):
    extracted_dir = path.join(target_dir, extracted_data)
    for source_csv in glob(path.join(extracted_dir, '*.csv')):
        _maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
    extracted_dir = os.path.join(target_dir, extracted_data)
    for source_csv in glob(os.path.join(extracted_dir, "*.csv")):
        _maybe_convert_set(
            extracted_dir,
            source_csv,
            os.path.join(target_dir, os.path.split(source_csv)[-1]),
        )


def one_sample(sample):
    mp3_filename = sample[0]
    # Storing wav files next to the mp3 ones - just with a different suffix
    wav_filename = path.splitext(mp3_filename)[0] + ".wav"
    _maybe_convert_wav(mp3_filename, wav_filename)
    frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    frames = int(
        subprocess.check_output(["soxi", "-s", wav_filename], stderr=subprocess.STDOUT)
    )
    file_size = -1
    if path.exists(wav_filename):
    if os.path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = validate_label(sample[1])
    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((wav_filename, file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
    counter["all"] += 1
    counter["total_time"] += frames
    return (counter, rows)


def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    print()
    if path.exists(target_csv):
    if os.path.exists(target_csv):
        print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
        return
    print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
@ -93,14 +108,14 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    with open(source_csv) as source_csv_file:
        reader = csv.DictReader(source_csv_file)
        for row in reader:
            samples.append((os.path.join(extracted_dir, row['filename']), row['text']))
            samples.append((os.path.join(extracted_dir, row["filename"]), row["text"]))

    # Mutable counters for the concurrent embedded routine
    counter = get_counter()
    num_samples = len(samples)
    rows = []

    print('Importing mp3 files...')
    print("Importing mp3 files...")
    pool = Pool()
    bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
    for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1):
@ -112,21 +127,28 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
    pool.join()

    print('Writing "%s"...' % target_csv)
    with open(target_csv, 'w') as target_csv_file:
    with open(target_csv, "w") as target_csv_file:
        writer = csv.DictWriter(target_csv_file, fieldnames=FIELDNAMES)
        writer.writeheader()
        bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
        for filename, file_size, transcript in bar(rows):
            writer.writerow({ 'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript })
            writer.writerow(
                {
                    "wav_filename": filename,
                    "wav_filesize": file_size,
                    "transcript": transcript,
                }
            )

    imported_samples = get_imported_samples(counter)
    assert counter['all'] == num_samples
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples

    print_import_report(counter, SAMPLE_RATE, MAX_SECS)


def _maybe_convert_wav(mp3_filename, wav_filename):
    if not path.exists(wav_filename):
    if not os.path.exists(wav_filename):
        transformer = sox.Transformer()
        transformer.convert(samplerate=SAMPLE_RATE)
        try:
@ -134,5 +156,6 @@ def _maybe_convert_wav(mp3_filename, wav_filename):
        except sox.core.SoxError:
            pass


if __name__ == "__main__":
    _download_and_preprocess_data(sys.argv[1])
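
The `too_short` branch above compares feature time steps to transcript length. `frames / SAMPLE_RATE * 1000 / 10 / 2` is presumably the number of 10 ms windows halved by a stride of two, i.e. the output steps available to CTC, and CTC cannot emit more characters than it has steps. A minimal sketch of the same gate under that reading:

SAMPLE_RATE = 16000

def fits_transcript(frames, label):
    # 10 ms windows, halved once -> output time steps available to CTC
    time_steps = int(frames / SAMPLE_RATE * 1000 / 10 / 2)
    return time_steps >= len(str(label))

assert fits_transcript(16000, "a" * 50)        # 1 s of audio -> 50 steps
assert not fits_transcript(16000, "a" * 51)    # one character too many
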
@ -1,90 +1,96 @@
#!/usr/bin/env python
'''
"""
Broadly speaking, this script takes the audio downloaded from Common Voice
for a certain language, in addition to the *.tsv files output by CorporaCreator,
and the script formats the data and transcripts to be in a state usable by
DeepSpeech.py
Use "python3 import_cv2.py -h" for help
'''
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

"""
import csv
import sox
import os
import subprocess
import progressbar
import unicodedata

from os import path
from multiprocessing import Pool
from util.downloader import SIMPLE_BAR
from util.text import Alphabet
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report

import progressbar
import sox

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    get_importers_parser,
    get_validate_label,
    print_import_report,
)
from deepspeech_training.util.text import Alphabet

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 10


def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
    for dataset in ['train', 'test', 'dev', 'validated', 'other']:
        input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
    for dataset in ["train", "test", "dev", "validated", "other"]:
        input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset + ".tsv")
        if os.path.isfile(input_tsv):
            print("Loading TSV file: ", input_tsv)
            _maybe_convert_set(input_tsv, audio_dir, space_after_every_character)


def one_sample(sample):
    """ Take a audio file, and optionally convert it to 16kHz WAV """
    mp3_filename = sample[0]
    if not path.splitext(mp3_filename.lower())[1] == '.mp3':
    if not os.path.splitext(mp3_filename.lower())[1] == ".mp3":
        mp3_filename += ".mp3"
    # Storing wav files next to the mp3 ones - just with a different suffix
    wav_filename = path.splitext(mp3_filename)[0] + ".wav"
    wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
    _maybe_convert_wav(mp3_filename, wav_filename)
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    if os.path.exists(wav_filename):
        file_size = os.path.getsize(wav_filename)
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = label_filter_fun(sample[1])
    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((os.path.split(wav_filename)[-1], file_size, label))
    counter['all'] += 1
    counter['total_time'] += frames
    counter["all"] += 1
    counter["total_time"] += frames

    return (counter, rows)


def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
    output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
    output_csv = os.path.join(
        audio_dir, os.path.split(input_tsv)[-1].replace("tsv", "csv")
    )
    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)

    # Get audiofile path and transcript for each sentence in tsv
    samples = []
    with open(input_tsv, encoding='utf-8') as input_tsv_file:
        reader = csv.DictReader(input_tsv_file, delimiter='\t')
    with open(input_tsv, encoding="utf-8") as input_tsv_file:
        reader = csv.DictReader(input_tsv_file, delimiter="\t")
        for row in reader:
            samples.append((path.join(audio_dir, row['path']), row['sentence']))
            samples.append((os.path.join(audio_dir, row["path"]), row["sentence"]))

    counter = get_counter()
    num_samples = len(samples)
@ -101,26 +107,38 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
    pool.close()
    pool.join()

    with open(output_csv, 'w', encoding='utf-8') as output_csv_file:
        print('Writing CSV file for DeepSpeech.py as: ', output_csv)
    with open(output_csv, "w", encoding="utf-8") as output_csv_file:
        print("Writing CSV file for DeepSpeech.py as: ", output_csv)
        writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
        writer.writeheader()
        bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
        for filename, file_size, transcript in bar(rows):
            if space_after_every_character:
                writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': ' '.join(transcript)})
                writer.writerow(
                    {
                        "wav_filename": filename,
                        "wav_filesize": file_size,
                        "transcript": " ".join(transcript),
                    }
                )
            else:
                writer.writerow({'wav_filename': filename, 'wav_filesize': file_size, 'transcript': transcript})
                writer.writerow(
                    {
                        "wav_filename": filename,
                        "wav_filesize": file_size,
                        "transcript": transcript,
                    }
                )

    imported_samples = get_imported_samples(counter)
    assert counter['all'] == num_samples
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples

    print_import_report(counter, SAMPLE_RATE, MAX_SECS)


def _maybe_convert_wav(mp3_filename, wav_filename):
    if not path.exists(wav_filename):
    if not os.path.exists(wav_filename):
        transformer = sox.Transformer()
        transformer.convert(samplerate=SAMPLE_RATE)
        try:
@ -130,24 +148,42 @@ def _maybe_convert_wav(mp3_filename, wav_filename):


if __name__ == "__main__":
    PARSER = get_importers_parser(description='Import CommonVoice v2.0 corpora')
    PARSER.add_argument('tsv_dir', help='Directory containing tsv files')
    PARSER.add_argument('--audio_dir', help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"')
    PARSER.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
    PARSER.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
    PARSER = get_importers_parser(description="Import CommonVoice v2.0 corpora")
    PARSER.add_argument("tsv_dir", help="Directory containing tsv files")
    PARSER.add_argument(
        "--audio_dir",
        help='Directory containing the audio clips - defaults to "<tsv_dir>/clips"',
    )
    PARSER.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet",
    )
    PARSER.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    PARSER.add_argument(
        "--space_after_every_character",
        action="store_true",
        help="To help transcript join by white space",
    )

    PARAMS = PARSER.parse_args()
    validate_label = get_validate_label(PARAMS)

    AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
    AUDIO_DIR = (
        PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
    )
    ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None

    def label_filter_fun(label):
        if PARAMS.normalize:
            label = unicodedata.normalize("NFKD", label.strip()) \
                .encode("ascii", "ignore") \
            label = (
                unicodedata.normalize("NFKD", label.strip())
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        label = validate_label(label)
        if ALPHABET and label:
            try:
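
The `--normalize` path in `label_filter_fun` strips diacritics by decomposing each character (NFKD) and then dropping whatever cannot be encoded as ASCII, which discards the combining accent marks. A quick round-trip illustration:

import unicodedata

def strip_diacritics(label):
    # NFKD turns "é" into "e" plus a combining accent; the ASCII
    # encode/decode pair silently drops the combining character.
    return (
        unicodedata.normalize("NFKD", label.strip())
        .encode("ascii", "ignore")
        .decode("ascii", "ignore")
    )

assert strip_diacritics("café crème") == "cafe creme"
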
@ -1,25 +1,20 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import codecs
import fnmatch
import os
import subprocess
import sys
import unicodedata

import librosa
import pandas
import soundfile  # <= Has an external dependency on libsndfile

from deepspeech_training.util.importers import validate_label_eng as validate_label

# Prerequisite: Having the sph2pipe tool in your PATH:
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import codecs
import fnmatch
import os
import pandas
import subprocess
import unicodedata
import librosa
import soundfile  # <= Has an external dependency on libsndfile

from util.importers import validate_label_eng as validate_label

def _download_and_preprocess_data(data_dir):
    # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
@ -29,33 +24,55 @@ def _download_and_preprocess_data(data_dir):
    _maybe_convert_wav(data_dir, "LDC2005S13", "fisher-2005-wav")

    # Conditionally split Fisher wav data
    all_2004 = _split_wav_and_sentences(data_dir,
                                        original_data="fisher-2004-wav",
                                        converted_data="fisher-2004-split-wav",
                                        trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"))
    all_2005 = _split_wav_and_sentences(data_dir,
                                        original_data="fisher-2005-wav",
                                        converted_data="fisher-2005-split-wav",
                                        trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"))
    all_2004 = _split_wav_and_sentences(
        data_dir,
        original_data="fisher-2004-wav",
        converted_data="fisher-2004-split-wav",
        trans_data=os.path.join("LDC2004T19", "fe_03_p1_tran", "data", "trans"),
    )
    all_2005 = _split_wav_and_sentences(
        data_dir,
        original_data="fisher-2005-wav",
        converted_data="fisher-2005-split-wav",
        trans_data=os.path.join("LDC2005T19", "fe_03_p2_tran", "data", "trans"),
    )

    # The following files have incorrect transcripts that are much longer than
    # their audio source. The result is that we end up with more labels than time
    # slices, which breaks CTC.
    all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"), "transcript"] = "correct"
    all_2004.loc[all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"), "transcript"] = "that's one of those"
    all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"), "transcript"] = "they don't want"
    all_2005.loc[all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"), "transcript"] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"
    all_2004.loc[
        all_2004["wav_filename"].str.endswith("fe_03_00265-33.53-33.81.wav"),
        "transcript",
    ] = "correct"
    all_2004.loc[
        all_2004["wav_filename"].str.endswith("fe_03_00991-527.39-528.3.wav"),
        "transcript",
    ] = "that's one of those"
    all_2005.loc[
        all_2005["wav_filename"].str.endswith("fe_03_10282-344.42-344.84.wav"),
        "transcript",
    ] = "they don't want"
    all_2005.loc[
        all_2005["wav_filename"].str.endswith("fe_03_10677-101.04-106.41.wav"),
        "transcript",
    ] = "uh my mine yeah the german shepherd pitbull mix he snores almost as loud as i do"

    # The following file is just a short sound and not at all transcribed like provided.
    # So we just exclude it.
    all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")]
    all_2004 = all_2004[
        ~all_2004["wav_filename"].str.endswith("fe_03_00027-393.8-394.05.wav")
    ]

    # The following file is far too long and would ruin our training batch size.
    # So we just exclude it.
    all_2005 = all_2005[~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")]
    all_2005 = all_2005[
        ~all_2005["wav_filename"].str.endswith("fe_03_11487-31.09-234.06.wav")
    ]

    # The following file is too large for its transcript, so we just exclude it.
    all_2004 = all_2004[~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")]
    all_2004 = all_2004[
        ~all_2004["wav_filename"].str.endswith("fe_03_01326-307.42-307.93.wav")
    ]

    # Conditionally split Fisher data into train/validation/test sets
    train_2004, dev_2004, test_2004 = _split_sets(all_2004)
@ -71,6 +88,7 @@ def _download_and_preprocess_data(data_dir):
    dev_files.to_csv(os.path.join(data_dir, "fisher-dev.csv"), index=False)
    test_files.to_csv(os.path.join(data_dir, "fisher-test.csv"), index=False)


def _maybe_convert_wav(data_dir, original_data, converted_data):
    source_dir = os.path.join(data_dir, original_data)
    target_dir = os.path.join(data_dir, converted_data)
@ -88,10 +106,18 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
        for filename in fnmatch.filter(filenames, "*.sph"):
            sph_file = os.path.join(root, filename)
            for channel in ["1", "2"]:
                wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "_c" + channel + ".wav"
                wav_filename = (
                    os.path.splitext(os.path.basename(sph_file))[0]
                    + "_c"
                    + channel
                    + ".wav"
                )
                wav_file = os.path.join(target_dir, wav_filename)
                print("converting {} to {}".format(sph_file, wav_file))
                subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file])
                subprocess.check_call(
                    ["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file]
                )


def _parse_transcriptions(trans_file):
    segments = []
@ -109,18 +135,23 @@ def _parse_transcriptions(trans_file):
        # We need to do the encode-decode dance here because encode
        # returns a bytes() object on Python 3, and text_to_char_array
        # expects a string.
        transcript = unicodedata.normalize("NFKD", transcript) \
            .encode("ascii", "ignore") \
            .decode("ascii", "ignore")
        transcript = (
            unicodedata.normalize("NFKD", transcript)
            .encode("ascii", "ignore")
            .decode("ascii", "ignore")
        )

        segments.append({
            "start_time": start_time,
            "stop_time": stop_time,
            "speaker": speaker,
            "transcript": transcript,
        })
        segments.append(
            {
                "start_time": start_time,
                "stop_time": stop_time,
                "speaker": speaker,
                "transcript": transcript,
            }
        )
    return segments


def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data):
    trans_dir = os.path.join(data_dir, trans_data)
    source_dir = os.path.join(data_dir, original_data)
@ -137,43 +168,73 @@ def _split_wav_and_sentences(data_dir, trans_data, original_data, converted_data
        segments = _parse_transcriptions(trans_file)

        # Open wav corresponding to transcription file
        wav_filenames = [os.path.splitext(os.path.basename(trans_file))[0] + "_c" + channel + ".wav" for channel in ["1", "2"]]
        wav_files = [os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames]
        wav_filenames = [
            os.path.splitext(os.path.basename(trans_file))[0]
            + "_c"
            + channel
            + ".wav"
            for channel in ["1", "2"]
        ]
        wav_files = [
            os.path.join(source_dir, wav_filename) for wav_filename in wav_filenames
        ]

        print("splitting {} according to {}".format(wav_files, trans_file))

        origAudios = [librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files]
        origAudios = [
            librosa.load(wav_file, sr=16000, mono=False) for wav_file in wav_files
        ]

        # Loop over segments and split wav_file for each segment
        for segment in segments:
            # Create wav segment filename
            start_time = segment["start_time"]
            stop_time = segment["stop_time"]
            new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
            new_wav_filename = (
                os.path.splitext(os.path.basename(trans_file))[0]
                + "-"
                + str(start_time)
                + "-"
                + str(stop_time)
                + ".wav"
            )
            new_wav_file = os.path.join(target_dir, new_wav_filename)

            channel = 0 if segment["speaker"] == "A:" else 1
            _split_and_resample_wav(origAudios[channel], start_time, stop_time, new_wav_file)
            _split_and_resample_wav(
                origAudios[channel], start_time, stop_time, new_wav_file
            )

            new_wav_filesize = os.path.getsize(new_wav_file)
            transcript = validate_label(segment["transcript"])
            if transcript != None:
                files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
                files.append(
                    (os.path.abspath(new_wav_file), new_wav_filesize, transcript)
                )

    return pandas.DataFrame(
        data=files, columns=["wav_filename", "wav_filesize", "transcript"]
    )

    return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])

def _split_audio(origAudio, start_time, stop_time):
    audioData, frameRate = origAudio
    nChannels = len(audioData.shape)
    startIndex = int(start_time * frameRate)
    stopIndex = int(stop_time * frameRate)
    return audioData[startIndex: stopIndex] if 1 == nChannels else audioData[:, startIndex: stopIndex]
    return (
        audioData[startIndex:stopIndex]
        if 1 == nChannels
        else audioData[:, startIndex:stopIndex]
    )


def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
    frameRate = origAudio[1]
    chunkData = _split_audio(origAudio, start_time, stop_time)
    soundfile.write(new_wav_file, chunkData, frameRate, "PCM_16")


def _split_sets(filelist):
    # We initially split the entire set into 80% train and 20% test, then
    # split the train set into 80% train and 20% validation.
@ -187,9 +248,12 @@ def _split_sets(filelist):
    test_beg = dev_end
    test_end = len(filelist)

    return (filelist[train_beg:train_end],
            filelist[dev_beg:dev_end],
            filelist[test_beg:test_end])
    return (
        filelist[train_beg:train_end],
        filelist[dev_beg:dev_end],
        filelist[test_beg:test_end],
    )


if __name__ == "__main__":
    _download_and_preprocess_data(sys.argv[1])
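
The `_split_sets` comment describes a nested 80/20 split: 80% train / 20% test first, then the train portion is split 80/20 again into train and validation, which works out to 64% / 16% / 20% overall. A minimal sketch of the boundary arithmetic under that scheme:

def split_boundaries(n):
    # 80% of 80% for train, the next 16% for dev, the final 20% for test
    train_end = int(0.8 * 0.8 * n)
    dev_end = int(0.8 * n)
    return (0, train_end), (train_end, dev_end), (dev_end, n)

train, dev, test = split_boundaries(1000)
assert train == (0, 640) and dev == (640, 800) and test == (800, 1000)
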
@ -1,24 +1,18 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser
import glob
import numpy as np
import pandas
import os
import tarfile

import numpy as np
import pandas

COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print('Extracting {} into {}...'.format(archive_path, target_dir))
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)

@ -26,7 +20,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, 'ST-CMDS-20170001_1-OS')
    main_folder = os.path.join(target_dir, "ST-CMDS-20170001_1-OS")

    # Folder structure is now:
    # - ST-CMDS-20170001_1-OS/
@ -39,16 +33,16 @@ def preprocess_data(tgz_file, target_dir):
        for wav in glob.glob(glob_path):
            wav_filename = wav
            wav_filesize = os.path.getsize(wav)
            txt_filename = os.path.splitext(wav_filename)[0] + '.txt'
            with open(txt_filename, 'r') as fin:
            txt_filename = os.path.splitext(wav_filename)[0] + ".txt"
            with open(txt_filename, "r") as fin:
                transcript = fin.read()
            set_files.append((wav_filename, wav_filesize, transcript))
        return set_files

    # Load all files, then deterministically split into train/dev/test sets
    all_files = load_set(os.path.join(main_folder, '*.wav'))
    all_files = load_set(os.path.join(main_folder, "*.wav"))
    df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
    df.sort_values(by='wav_filename', inplace=True)
    df.sort_values(by="wav_filename", inplace=True)

    indices = np.arange(0, len(df))
    np.random.seed(12345)
@ -61,29 +55,33 @@ def preprocess_data(tgz_file, target_dir):
    train_indices = indices[:-10000]

    train_files = df.iloc[train_indices]
    durations = (train_files['wav_filesize'] - 44) / 16000 / 2
    durations = (train_files["wav_filesize"] - 44) / 16000 / 2
    train_files = train_files[durations <= 10.0]
    print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))
    dest_csv = os.path.join(target_dir, 'freestmandarin_train.csv')
    print('Saving train set into {}...'.format(dest_csv))
    print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))
    dest_csv = os.path.join(target_dir, "freestmandarin_train.csv")
    print("Saving train set into {}...".format(dest_csv))
    train_files.to_csv(dest_csv, index=False)

    dev_files = df.iloc[dev_indices]
    dest_csv = os.path.join(target_dir, 'freestmandarin_dev.csv')
    print('Saving dev set into {}...'.format(dest_csv))
    dest_csv = os.path.join(target_dir, "freestmandarin_dev.csv")
    print("Saving dev set into {}...".format(dest_csv))
    dev_files.to_csv(dest_csv, index=False)

    test_files = df.iloc[test_indices]
    dest_csv = os.path.join(target_dir, 'freestmandarin_test.csv')
    print('Saving test set into {}...'.format(dest_csv))
    dest_csv = os.path.join(target_dir, "freestmandarin_test.csv")
    print("Saving test set into {}...".format(dest_csv))
    test_files.to_csv(dest_csv, index=False)


def main():
    # https://www.openslr.org/38/
    parser = get_importers_parser(description='Import Free ST Chinese Mandarin corpus')
    parser.add_argument('tgz_file', help='Path to ST-CMDS-20170001_1-OS.tar.gz')
    parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
    parser = get_importers_parser(description="Import Free ST Chinese Mandarin corpus")
    parser.add_argument("tgz_file", help="Path to ST-CMDS-20170001_1-OS.tar.gz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()

    if not params.target_dir:
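
The Free ST Mandarin importer sorts by filename and seeds NumPy before shuffling, so the same 10,000 samples are held out on every run regardless of filesystem ordering. The hunk elides how the holdout is divided between dev and test, so the even split below is only an assumption; a sketch of the pattern:

import numpy as np

def deterministic_split(df, holdout=10000):
    # Sort first so the shuffle input is identical across machines,
    # then permute with a fixed seed for a reproducible split.
    df = df.sort_values(by="wav_filename")
    indices = np.arange(0, len(df))
    np.random.seed(12345)
    np.random.shuffle(indices)
    half = holdout // 2  # assumed: holdout shared evenly by dev and test
    train = df.iloc[indices[:-holdout]]
    dev = df.iloc[indices[-holdout:-half]]
    test = df.iloc[indices[-half:]]
    return train, dev, test
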
@ -1,24 +1,18 @@
#!/usr/bin/env python

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import csv
import math
import urllib
import logging
from util.importers import get_importers_parser, get_validate_label
import math
import os
import subprocess
from os import path
import urllib
from pathlib import Path

import swifter
import pandas as pd
from sox import Transformer

import swifter
from deepspeech_training.util.importers import get_importers_parser, get_validate_label

__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@ -40,9 +34,7 @@ def parse_args(args):
    Returns:
        :obj:`argparse.Namespace`: command line parameters namespace
    """
    parser = get_importers_parser(
        description="Imports GramVaani data for Deep Speech"
    )
    parser = get_importers_parser(description="Imports GramVaani data for Deep Speech")
    parser.add_argument(
        "--version",
        action="version",
@ -82,6 +74,7 @@ def parse_args(args):
    )
    return parser.parse_args(args)


def setup_logging(level):
    """Setup basic logging
    Args:
@ -92,6 +85,7 @@ def setup_logging(level):
        level=level, stream=sys.stdout, format=format, datefmt="%Y-%m-%d %H:%M:%S"
    )


class GramVaaniCSV:
    """GramVaaniCSV representing a GramVaani dataset.
    Args:
@ -107,8 +101,17 @@ class GramVaaniCSV:
        _logger.info("Parsing csv file...%s", os.path.abspath(csv_filename))
        data = pd.read_csv(
            os.path.abspath(csv_filename),
            names=["piece_id","audio_url","transcript_labelled","transcript","labels","content_filename","audio_length","user_id"],
            usecols=["audio_url","transcript","audio_length"],
            names=[
                "piece_id",
                "audio_url",
                "transcript_labelled",
                "transcript",
                "labels",
                "content_filename",
                "audio_length",
                "user_id",
            ],
            usecols=["audio_url", "transcript", "audio_length"],
            skiprows=[0],
            engine="python",
            encoding="utf-8",
@ -119,6 +122,7 @@ class GramVaaniCSV:
        _logger.info("Parsed %d lines csv file." % len(data))
        return data


class GramVaaniDownloader:
    """GramVaaniDownloader downloads a GramVaani dataset.
    Args:
@ -138,15 +142,17 @@ class GramVaaniDownloader:
            mp3_directory (os.path): The directory into which the associated mp3's were downloaded
        """
        mp3_directory = self._pre_download()
        self.data.swifter.apply(func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True)
        self.data.swifter.apply(
            func=lambda arg: self._download(*arg, mp3_directory), axis=1, raw=True
        )
        return mp3_directory

    def _pre_download(self):
        mp3_directory = path.join(self.target_dir, "mp3")
        if not path.exists(self.target_dir):
        mp3_directory = os.path.join(self.target_dir, "mp3")
        if not os.path.exists(self.target_dir):
            _logger.info("Creating directory...%s", self.target_dir)
            os.mkdir(self.target_dir)
        if not path.exists(mp3_directory):
        if not os.path.exists(mp3_directory):
            _logger.info("Creating directory...%s", mp3_directory)
            os.mkdir(mp3_directory)
        return mp3_directory
@ -154,13 +160,14 @@ class GramVaaniDownloader:
    def _download(self, audio_url, transcript, audio_length, mp3_directory):
        if audio_url == "audio_url":
            return
        mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
        if not path.exists(mp3_filename):
        mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
        if not os.path.exists(mp3_filename):
            _logger.debug("Downloading mp3 file...%s", audio_url)
            urllib.request.urlretrieve(audio_url, mp3_filename)
        else:
            _logger.debug("Already downloaded mp3 file...%s", audio_url)


class GramVaaniConverter:
    """GramVaaniConverter converts the mp3's to wav's for a GramVaani dataset.
    Args:
@ -181,37 +188,53 @@ class GramVaaniConverter:
            wav_directory (os.path): The directory into which the associated wav's were downloaded
        """
        wav_directory = self._pre_convert()
        for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
            wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
            if not path.exists(wav_filename):
                _logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
        for mp3_filename in self.mp3_directory.glob("**/*.mp3"):
            wav_filename = os.path.join(
                wav_directory,
                os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav",
            )
            if not os.path.exists(wav_filename):
                _logger.debug(
                    "Converting mp3 file %s to wav file %s"
                    % (mp3_filename, wav_filename)
                )
                transformer = Transformer()
                transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
                transformer.convert(
                    samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
                )
                transformer.build(str(mp3_filename), str(wav_filename))
            else:
                _logger.debug("Already converted mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
                _logger.debug(
                    "Already converted mp3 file %s to wav file %s"
                    % (mp3_filename, wav_filename)
                )
        return wav_directory

    def _pre_convert(self):
        wav_directory = path.join(self.target_dir, "wav")
        if not path.exists(self.target_dir):
        wav_directory = os.path.join(self.target_dir, "wav")
        if not os.path.exists(self.target_dir):
            _logger.info("Creating directory...%s", self.target_dir)
            os.mkdir(self.target_dir)
        if not path.exists(wav_directory):
        if not os.path.exists(wav_directory):
            _logger.info("Creating directory...%s", wav_directory)
            os.mkdir(wav_directory)
        return wav_directory


class GramVaaniDataSets:
    def __init__(self, target_dir, wav_directory, gram_vaani_csv):
        self.target_dir = target_dir
        self.wav_directory = wav_directory
        self.csv_data = gram_vaani_csv.data
        self.raw = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
        self.valid = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
        self.train = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
        self.dev = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
        self.test = pd.DataFrame(columns=["wav_filename","wav_filesize","transcript"])
        self.raw = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
        self.valid = pd.DataFrame(
            columns=["wav_filename", "wav_filesize", "transcript"]
        )
        self.train = pd.DataFrame(
            columns=["wav_filename", "wav_filesize", "transcript"]
        )
        self.dev = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])
        self.test = pd.DataFrame(columns=["wav_filename", "wav_filesize", "transcript"])

    def create(self):
        self._convert_csv_data_to_raw_data()
@ -220,30 +243,45 @@ class GramVaaniDataSets:
        self.valid = self.valid.sample(frac=1).reset_index(drop=True)
        train_size, dev_size, test_size = self._calculate_data_set_sizes()
        self.train = self.valid.loc[0:train_size]
        self.dev = self.valid.loc[train_size:train_size+dev_size]
        self.test = self.valid.loc[train_size+dev_size:train_size+dev_size+test_size]
        self.dev = self.valid.loc[train_size : train_size + dev_size]
        self.test = self.valid.loc[
            train_size + dev_size : train_size + dev_size + test_size
        ]

    def _convert_csv_data_to_raw_data(self):
        self.raw[["wav_filename","wav_filesize","transcript"]] = self.csv_data[
            ["audio_url","transcript","audio_length"]
        ].swifter.apply(func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg), axis=1, raw=True)
        self.raw[["wav_filename", "wav_filesize", "transcript"]] = self.csv_data[
            ["audio_url", "transcript", "audio_length"]
        ].swifter.apply(
            func=lambda arg: self._convert_csv_data_to_raw_data_impl(*arg),
            axis=1,
            raw=True,
        )
        self.raw.reset_index()

    def _convert_csv_data_to_raw_data_impl(self, audio_url, transcript, audio_length):
        if audio_url == "audio_url":
            return pd.Series(["wav_filename", "wav_filesize", "transcript"])
        mp3_filename = os.path.basename(audio_url)
        wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
        wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
        wav_relative_filename = os.path.join(
            "wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav"
        )
        wav_filesize = os.path.getsize(
            os.path.join(self.target_dir, wav_relative_filename)
        )
        transcript = validate_label(transcript)
        if None == transcript:
            transcript = ""
        return pd.Series([wav_relative_filename, wav_filesize, transcript])
        return pd.Series([wav_relative_filename, wav_filesize, transcript])

    def _is_valid_raw_rows(self):
        is_valid_raw_transcripts = self._is_valid_raw_transcripts()
        is_valid_raw_wav_frames = self._is_valid_raw_wav_frames()
        is_valid_raw_row = [(is_valid_raw_transcript & is_valid_raw_wav_frame) for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(is_valid_raw_transcripts, is_valid_raw_wav_frames)]
        is_valid_raw_row = [
            (is_valid_raw_transcript & is_valid_raw_wav_frame)
            for is_valid_raw_transcript, is_valid_raw_wav_frame in zip(
                is_valid_raw_transcripts, is_valid_raw_wav_frames
            )
        ]
        series = pd.Series(is_valid_raw_row)
        return series

@ -252,16 +290,29 @@ class GramVaaniDataSets:

    def _is_valid_raw_wav_frames(self):
        transcripts = [str(transcript) for transcript in self.raw.transcript]
        wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
        wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
        is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
        wav_filepaths = [
            os.path.join(self.target_dir, str(wav_filename))
            for wav_filename in self.raw.wav_filename
        ]
        wav_frames = [
            int(
                subprocess.check_output(
                    ["soxi", "-s", wav_filepath], stderr=subprocess.STDOUT
                )
            )
            for wav_filepath in wav_filepaths
        ]
        is_valid_raw_wav_frames = [
            self._is_wav_frame_valid(wav_frame, transcript)
            for wav_frame, transcript in zip(wav_frames, transcripts)
        ]
        return pd.Series(is_valid_raw_wav_frames)

    def _is_wav_frame_valid(self, wav_frame, transcript):
        is_wav_frame_valid = True
        if int(wav_frame/SAMPLE_RATE*1000/10/2) < len(str(transcript)):
        if int(wav_frame / SAMPLE_RATE * 1000 / 10 / 2) < len(str(transcript)):
            is_wav_frame_valid = False
        elif wav_frame/SAMPLE_RATE > MAX_SECS:
        elif wav_frame / SAMPLE_RATE > MAX_SECS:
            is_wav_frame_valid = False
        return is_wav_frame_valid

@ -280,7 +331,14 @@ class GramVaaniDataSets:
    def _save(self, dataset):
        dataset_path = os.path.join(self.target_dir, dataset + ".csv")
        dataframe = getattr(self, dataset)
        dataframe.to_csv(dataset_path, index=False, encoding="utf-8", escapechar='\\', quoting=csv.QUOTE_MINIMAL)
        dataframe.to_csv(
            dataset_path,
            index=False,
            encoding="utf-8",
            escapechar="\\",
            quoting=csv.QUOTE_MINIMAL,
        )


def main(args):
    """Main entry point allowing external calls
@ -304,4 +362,5 @@ def main(args):
    datasets.save()
    _logger.info("Finished GramVaani importer...")


    main(sys.argv[1:])
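
`swifter.apply` used throughout the GramVaani classes is a drop-in wrapper around `DataFrame.apply` that can parallelize the call; with `raw=True`, pandas hands each row to the function as a bare array, which the lambdas unpack into named arguments. A small sketch of that call shape with a hypothetical row handler:

import pandas as pd
import swifter  # noqa: F401 -- importing registers the .swifter accessor

df = pd.DataFrame(
    {"audio_url": ["x.mp3"], "transcript": ["hello"], "audio_length": [1.5]}
)

def handle_row(audio_url, transcript, audio_length):
    # Hypothetical per-row work; the importer downloads or validates here.
    return pd.Series([audio_url, len(transcript), audio_length])

out = df.swifter.apply(func=lambda arg: handle_row(*arg), axis=1, raw=True)
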
@ -1,28 +1,33 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import sys

import pandas

from util.downloader import maybe_download
from deepspeech_training.util.downloader import maybe_download


def _download_and_preprocess_data(data_dir):
    # Conditionally download data
    LDC93S1_BASE = "LDC93S1"
    LDC93S1_BASE_URL = "https://catalog.ldc.upenn.edu/desc/addenda/"
    local_file = maybe_download(LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav")
    trans_file = maybe_download(LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt")
    local_file = maybe_download(
        LDC93S1_BASE + ".wav", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".wav"
    )
    trans_file = maybe_download(
        LDC93S1_BASE + ".txt", data_dir, LDC93S1_BASE_URL + LDC93S1_BASE + ".txt"
    )
    with open(trans_file, "r") as fin:
        transcript = ' '.join(fin.read().strip().lower().split(' ')[2:]).replace('.', '')
        transcript = " ".join(fin.read().strip().lower().split(" ")[2:]).replace(
            ".", ""
        )

    df = pandas.DataFrame(data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
                          columns=["wav_filename", "wav_filesize", "transcript"])
    df = pandas.DataFrame(
        data=[(os.path.abspath(local_file), os.path.getsize(local_file), transcript)],
        columns=["wav_filename", "wav_filesize", "transcript"],
    )
    df.to_csv(os.path.join(data_dir, "ldc93s1.csv"), index=False)


if __name__ == "__main__":
    _download_and_preprocess_data(sys.argv[1])
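
The transcript post-processing above assumes the LDC93S1 `.txt` puts two sample-offset fields in front of the sentence, so `split(' ')[2:]` drops them before the text is lowercased and stripped of its period. An illustration with a line in that layout:

# A line in the "<start_sample> <end_sample> <sentence>." layout.
line = "0 46797 She had your dark suit in greasy wash water all year.\n"
transcript = " ".join(line.strip().lower().split(" ")[2:]).replace(".", "")
assert transcript == "she had your dark suit in greasy wash water all year"
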
@ -1,33 +1,39 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import codecs
import fnmatch
import pandas
import progressbar
import os
import subprocess
import sys
import tarfile
import unicodedata

import pandas
import progressbar
from sox import Transformer
from util.downloader import maybe_download
from tensorflow.python.platform import gfile

from deepspeech_training.util.downloader import maybe_download

SAMPLE_RATE = 16000


def _download_and_preprocess_data(data_dir):
    # Conditionally download data to data_dir
    print("Downloading Librivox data set (55GB) into {} if not already present...".format(data_dir))
    print(
        "Downloading Librivox data set (55GB) into {} if not already present...".format(
            data_dir
        )
    )
    with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
        TRAIN_CLEAN_100_URL = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
        TRAIN_CLEAN_360_URL = "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
        TRAIN_OTHER_500_URL = "http://www.openslr.org/resources/12/train-other-500.tar.gz"
        TRAIN_CLEAN_100_URL = (
            "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
        )
        TRAIN_CLEAN_360_URL = (
            "http://www.openslr.org/resources/12/train-clean-360.tar.gz"
        )
        TRAIN_OTHER_500_URL = (
            "http://www.openslr.org/resources/12/train-other-500.tar.gz"
        )

        DEV_CLEAN_URL = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
        DEV_OTHER_URL = "http://www.openslr.org/resources/12/dev-other.tar.gz"
@ -35,12 +41,20 @@ def _download_and_preprocess_data(data_dir):
        TEST_CLEAN_URL = "http://www.openslr.org/resources/12/test-clean.tar.gz"
        TEST_OTHER_URL = "http://www.openslr.org/resources/12/test-other.tar.gz"

        def filename_of(x): return os.path.split(x)[1]
        train_clean_100 = maybe_download(filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL)
        def filename_of(x):
            return os.path.split(x)[1]

        train_clean_100 = maybe_download(
            filename_of(TRAIN_CLEAN_100_URL), data_dir, TRAIN_CLEAN_100_URL
        )
        bar.update(0)
        train_clean_360 = maybe_download(filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL)
        train_clean_360 = maybe_download(
            filename_of(TRAIN_CLEAN_360_URL), data_dir, TRAIN_CLEAN_360_URL
        )
        bar.update(1)
        train_other_500 = maybe_download(filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL)
        train_other_500 = maybe_download(
            filename_of(TRAIN_OTHER_500_URL), data_dir, TRAIN_OTHER_500_URL
        )
        bar.update(2)

        dev_clean = maybe_download(filename_of(DEV_CLEAN_URL), data_dir, DEV_CLEAN_URL)
@ -48,9 +62,13 @@ def _download_and_preprocess_data(data_dir):
        dev_other = maybe_download(filename_of(DEV_OTHER_URL), data_dir, DEV_OTHER_URL)
        bar.update(4)

        test_clean = maybe_download(filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL)
        test_clean = maybe_download(
            filename_of(TEST_CLEAN_URL), data_dir, TEST_CLEAN_URL
        )
        bar.update(5)
        test_other = maybe_download(filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL)
        test_other = maybe_download(
            filename_of(TEST_OTHER_URL), data_dir, TEST_OTHER_URL
        )
        bar.update(6)

    # Conditionally extract LibriSpeech data
@ -61,11 +79,17 @@ def _download_and_preprocess_data(data_dir):
    LIBRIVOX_DIR = "LibriSpeech"
    work_dir = os.path.join(data_dir, LIBRIVOX_DIR)

    _maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100)
    _maybe_extract(
        data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-100"), train_clean_100
    )
    bar.update(0)
    _maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360)
    _maybe_extract(
        data_dir, os.path.join(LIBRIVOX_DIR, "train-clean-360"), train_clean_360
    )
    bar.update(1)
    _maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500)
    _maybe_extract(
        data_dir, os.path.join(LIBRIVOX_DIR, "train-other-500"), train_other_500
    )
    bar.update(2)

    _maybe_extract(data_dir, os.path.join(LIBRIVOX_DIR, "dev-clean"), dev_clean)
@ -91,28 +115,48 @@ def _download_and_preprocess_data(data_dir):
    # data_dir/LibriSpeech/split-wav/1-2-2.txt
    # ...
    print("Converting FLAC to WAV and splitting transcriptions...")
    with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
        train_100 = _convert_audio_and_split_sentences(work_dir, "train-clean-100", "train-clean-100-wav")
    with progressbar.ProgressBar(max_value=7, widget=progressbar.AdaptiveETA) as bar:
        train_100 = _convert_audio_and_split_sentences(
            work_dir, "train-clean-100", "train-clean-100-wav"
        )
        bar.update(0)
        train_360 = _convert_audio_and_split_sentences(work_dir, "train-clean-360", "train-clean-360-wav")
        train_360 = _convert_audio_and_split_sentences(
            work_dir, "train-clean-360", "train-clean-360-wav"
        )
        bar.update(1)
        train_500 = _convert_audio_and_split_sentences(work_dir, "train-other-500", "train-other-500-wav")
        train_500 = _convert_audio_and_split_sentences(
            work_dir, "train-other-500", "train-other-500-wav"
        )
        bar.update(2)

        dev_clean = _convert_audio_and_split_sentences(work_dir, "dev-clean", "dev-clean-wav")
        dev_clean = _convert_audio_and_split_sentences(
            work_dir, "dev-clean", "dev-clean-wav"
        )
        bar.update(3)
        dev_other = _convert_audio_and_split_sentences(work_dir, "dev-other", "dev-other-wav")
        dev_other = _convert_audio_and_split_sentences(
            work_dir, "dev-other", "dev-other-wav"
        )
        bar.update(4)

        test_clean = _convert_audio_and_split_sentences(work_dir, "test-clean", "test-clean-wav")
        test_clean = _convert_audio_and_split_sentences(
            work_dir, "test-clean", "test-clean-wav"
        )
        bar.update(5)
        test_other = _convert_audio_and_split_sentences(work_dir, "test-other", "test-other-wav")
        test_other = _convert_audio_and_split_sentences(
            work_dir, "test-other", "test-other-wav"
        )
        bar.update(6)

    # Write sets to disk as CSV files
    train_100.to_csv(os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False)
    train_360.to_csv(os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False)
    train_500.to_csv(os.path.join(data_dir, "librivox-train-other-500.csv"), index=False)
    train_100.to_csv(
        os.path.join(data_dir, "librivox-train-clean-100.csv"), index=False
    )
    train_360.to_csv(
        os.path.join(data_dir, "librivox-train-clean-360.csv"), index=False
    )
    train_500.to_csv(
        os.path.join(data_dir, "librivox-train-other-500.csv"), index=False
    )

    dev_clean.to_csv(os.path.join(data_dir, "librivox-dev-clean.csv"), index=False)
    dev_other.to_csv(os.path.join(data_dir, "librivox-dev-other.csv"), index=False)
@ -120,6 +164,7 @@ def _download_and_preprocess_data(data_dir):
    test_clean.to_csv(os.path.join(data_dir, "librivox-test-clean.csv"), index=False)
    test_other.to_csv(os.path.join(data_dir, "librivox-test-other.csv"), index=False)


def _maybe_extract(data_dir, extracted_data, archive):
    # If data_dir/extracted_data does not exist, extract archive in data_dir
    if not gfile.Exists(os.path.join(data_dir, extracted_data)):
@ -127,6 +172,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
        tar.extractall(data_dir)
    tar.close()


def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
    source_dir = os.path.join(extracted_dir, data_set)
    target_dir = os.path.join(extracted_dir, dest_dir)
@ -149,20 +195,22 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
    # We also convert the corresponding FLACs to WAV in the same pass
|
||||
files = []
|
||||
for root, dirnames, filenames in os.walk(source_dir):
|
||||
for filename in fnmatch.filter(filenames, '*.trans.txt'):
|
||||
for filename in fnmatch.filter(filenames, "*.trans.txt"):
|
||||
trans_filename = os.path.join(root, filename)
|
||||
with codecs.open(trans_filename, "r", "utf-8") as fin:
|
||||
for line in fin:
|
||||
# Parse each segment line
|
||||
first_space = line.find(" ")
|
||||
seqid, transcript = line[:first_space], line[first_space+1:]
|
||||
seqid, transcript = line[:first_space], line[first_space + 1 :]
|
||||
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
transcript = transcript.lower().strip()
|
||||
|
||||
@ -177,7 +225,10 @@ def _convert_audio_and_split_sentences(extracted_dir, data_set, dest_dir):
|
||||
|
||||
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,44 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sox
|
||||
import zipfile
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
import zipfile
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
|
||||
ARCHIVE_DIR_NAME = 'lingua_libre'
|
||||
ARCHIVE_NAME = 'Q{qId}-{iso639_3}-{language_English_name}.zip'
|
||||
ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
|
||||
ARCHIVE_DIR_NAME = "lingua_libre"
|
||||
ARCHIVE_NAME = "Q{qId}-{iso639_3}-{language_English_name}.zip"
|
||||
ARCHIVE_URL = "https://lingualibre.fr/datasets/" + ARCHIVE_NAME
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -46,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
|
||||
# Produce CSV files and convert ogg data to wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -58,57 +54,70 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
else:
|
||||
print('Found directory "%s" - not extracting it from archive.' % archive_path)
|
||||
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
    ogg_filename = sample[0]
    # Storing wav files next to the ogg ones - just with a different suffix
    wav_filename = path.splitext(ogg_filename)[0] + ".wav"
    wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
    _maybe_convert_wav(ogg_filename, wav_filename)
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    if os.path.exists(wav_filename):
        file_size = os.path.getsize(wav_filename)
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = label_filter(sample[1])
    rows = []
    counter = get_counter()

    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((wav_filename, file_size, label))
        counter['all'] += 1
        counter['total_time'] += frames
        counter["all"] += 1
        counter["total_time"] += frames

    return (counter, rows)


def _maybe_convert_sets(target_dir, extracted_data):
    extracted_dir = path.join(target_dir, extracted_data)
    extracted_dir = os.path.join(target_dir, extracted_data)
    # override existing CSV with normalized one
    target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
    target_csv_template = os.path.join(
        target_dir, ARCHIVE_DIR_NAME + "_" + ARCHIVE_NAME.replace(".zip", "_{}.csv")
    )
    if os.path.isfile(target_csv_template):
        return

    ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace('.zip', ''))
    ogg_root_dir = os.path.join(extracted_dir, ARCHIVE_NAME.replace(".zip", ""))

    # Get audiofile path and transcript for each sentence in tsv
    samples = []
    glob_dir = os.path.join(ogg_root_dir, '**/*.ogg')
    glob_dir = os.path.join(ogg_root_dir, "**/*.ogg")
    for record in glob(glob_dir, recursive=True):
        record_file = record.replace(ogg_root_dir + os.path.sep, '')
        record_file = record.replace(ogg_root_dir + os.path.sep, "")
        if record_filter(record_file):
            samples.append((os.path.join(ogg_root_dir, record_file), os.path.splitext(os.path.basename(record_file))[0]))
            samples.append(
                (
                    os.path.join(ogg_root_dir, record_file),
                    os.path.splitext(os.path.basename(record_file))[0],
                )
            )

    counter = get_counter()
    num_samples = len(samples)
@ -125,9 +134,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
    pool.close()
    pool.join()

    with open(target_csv_template.format('train'), 'w') as train_csv_file:  # 80%
        with open(target_csv_template.format('dev'), 'w') as dev_csv_file:  # 10%
            with open(target_csv_template.format('test'), 'w') as test_csv_file:  # 10%
    with open(target_csv_template.format("train"), "w") as train_csv_file:  # 80%
        with open(target_csv_template.format("dev"), "w") as dev_csv_file:  # 10%
            with open(target_csv_template.format("test"), "w") as test_csv_file:  # 10%
                train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
                train_writer.writeheader()
                dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -139,7 +148,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
                    transcript = validate_label(item[2])
                    if not transcript:
                        continue
                    wav_filename = os.path.join(ogg_root_dir, item[0].replace('.ogg', '.wav'))
                    wav_filename = os.path.join(
                        ogg_root_dir, item[0].replace(".ogg", ".wav")
                    )
                    i_mod = i % 10
                    if i_mod == 0:
                        writer = test_writer
@ -147,38 +158,63 @@ def _maybe_convert_sets(target_dir, extracted_data):
                        writer = dev_writer
                    else:
                        writer = train_writer
                    writer.writerow(dict(
                        wav_filename=wav_filename,
                        wav_filesize=os.path.getsize(wav_filename),
                        transcript=transcript,
                    ))
                    writer.writerow(
                        dict(
                            wav_filename=wav_filename,
                            wav_filesize=os.path.getsize(wav_filename),
                            transcript=transcript,
                        )
                    )

    imported_samples = get_imported_samples(counter)
    assert counter['all'] == num_samples
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples

    print_import_report(counter, SAMPLE_RATE, MAX_SECS)


def _maybe_convert_wav(ogg_filename, wav_filename):
    if not path.exists(wav_filename):
    if not os.path.exists(wav_filename):
        transformer = sox.Transformer()
        transformer.convert(samplerate=SAMPLE_RATE)
        try:
            transformer.build(ogg_filename, wav_filename)
        except sox.core.SoxError as ex:
            print('SoX processing error', ex, ogg_filename, wav_filename)
            print("SoX processing error", ex, ogg_filename, wav_filename)


def handle_args():
    parser = get_importers_parser(description='Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details.')
    parser.add_argument(dest='target_dir')
    parser.add_argument('--qId', type=int, required=True, help='LinguaLibre language qId')
    parser.add_argument('--iso639-3', type=str, required=True, help='ISO639-3 language code')
    parser.add_argument('--english-name', type=str, required=True, help='English name of the language')
    parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
    parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    parser.add_argument('--bogus-records', type=argparse.FileType('r'), required=False, help='Text file listing well-known bogus records to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items')
    parser = get_importers_parser(
        description="Importer for LinguaLibre dataset. Check https://lingualibre.fr/wiki/Help:Download_from_LinguaLibre for details."
    )
    parser.add_argument(dest="target_dir")
    parser.add_argument(
        "--qId", type=int, required=True, help="LinguaLibre language qId"
    )
    parser.add_argument(
        "--iso639-3", type=str, required=True, help="ISO639-3 language code"
    )
    parser.add_argument(
"--english-name", type=str, required=True, help="Enligh name of the language"
|
||||
    )
    parser.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    parser.add_argument(
        "--bogus-records",
        type=argparse.FileType("r"),
        required=False,
help="Text file listing well-known bogus record to skip from importing, from https://lingualibre.fr/wiki/LinguaLibre:Misleading_items",
|
||||
    )
    return parser.parse_args()


if __name__ == "__main__":
    CLI_ARGS = handle_args()
    ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
@ -191,15 +227,17 @@ if __name__ == "__main__":

    def record_filter(path):
        if any(regex.match(path) for regex in bogus_regexes):
            print('Reject', path)
            print("Reject", path)
            return False
        return True

    def label_filter(label):
        if CLI_ARGS.normalize:
            label = unicodedata.normalize("NFKD", label.strip()) \
                .encode("ascii", "ignore") \
            label = (
                unicodedata.normalize("NFKD", label.strip())
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        label = validate_label(label)
        if ALPHABET and label:
            try:
@ -208,6 +246,14 @@ if __name__ == "__main__":
                label = None
        return label

    ARCHIVE_NAME = ARCHIVE_NAME.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
    ARCHIVE_URL = ARCHIVE_URL.format(qId=CLI_ARGS.qId, iso639_3=CLI_ARGS.iso639_3, language_English_name=CLI_ARGS.english_name)
    ARCHIVE_NAME = ARCHIVE_NAME.format(
        qId=CLI_ARGS.qId,
        iso639_3=CLI_ARGS.iso639_3,
        language_English_name=CLI_ARGS.english_name,
    )
    ARCHIVE_URL = ARCHIVE_URL.format(
        qId=CLI_ARGS.qId,
        iso639_3=CLI_ARGS.iso639_3,
        language_English_name=CLI_ARGS.english_name,
    )
    _download_and_preprocess_data(target_dir=CLI_ARGS.target_dir)

@ -1,43 +1,37 @@
#!/usr/bin/env python3
# pylint: disable=invalid-name
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report

import csv
import os
import subprocess
import progressbar
import unicodedata
import tarfile

from multiprocessing import Pool
from util.downloader import SIMPLE_BAR

from os import path
import unicodedata
from glob import glob
from multiprocessing import Pool

from util.downloader import maybe_download
from util.text import Alphabet
import progressbar

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    get_importers_parser,
    get_validate_label,
    print_import_report,
)
from deepspeech_training.util.text import Alphabet

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15

ARCHIVE_DIR_NAME = '{language}'
ARCHIVE_NAME = '{language}.tgz'
ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "{language}"
ARCHIVE_NAME = "{language}.tgz"
ARCHIVE_URL = "http://www.caito.de/data/Training/stt_tts/" + ARCHIVE_NAME


def _download_and_preprocess_data(target_dir):
    # Making path absolute
    target_dir = path.abspath(target_dir)
    target_dir = os.path.abspath(target_dir)
    # Conditionally download data
    archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
    # Conditionally extract data
@ -48,8 +42,8 @@ def _download_and_preprocess_data(target_dir):

def _maybe_extract(target_dir, extracted_data, archive_path):
    # If target_dir/extracted_data does not exist, extract archive in target_dir
    extracted_path = path.join(target_dir, extracted_data)
    if not path.exists(extracted_path):
    extracted_path = os.path.join(target_dir, extracted_data)
    if not os.path.exists(extracted_path):
        print('No directory "%s" - extracting archive...' % extracted_path)
        if not os.path.isdir(extracted_path):
            os.mkdir(extracted_path)
@ -65,9 +59,13 @@ def one_sample(sample):
    wav_filename = sample[0]
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    if os.path.exists(wav_filename):
        file_size = os.path.getsize(wav_filename)
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = label_filter(sample[1])
    counter = get_counter()
    rows = []
@ -75,27 +73,30 @@ def one_sample(sample):
    if file_size == -1:
        # Excluding samples that failed upon conversion
        print("conversion failure", wav_filename)
        counter['failed'] += 1
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((wav_filename, file_size, label))
        counter['all'] += 1
        counter['total_time'] += frames
        counter["all"] += 1
        counter["total_time"] += frames
    return (counter, rows)


def _maybe_convert_sets(target_dir, extracted_data):
    extracted_dir = path.join(target_dir, extracted_data)
    extracted_dir = os.path.join(target_dir, extracted_data)
    # override existing CSV with normalized one
    target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
    target_csv_template = os.path.join(
        target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tgz", "_{}.csv")
    )
    if os.path.isfile(target_csv_template):
        return

@ -103,14 +104,16 @@ def _maybe_convert_sets(target_dir, extracted_data):

    # Get audiofile path and transcript for each sentence in tsv
    samples = []
    glob_dir = os.path.join(wav_root_dir, '**/metadata.csv')
    glob_dir = os.path.join(wav_root_dir, "**/metadata.csv")
    for record in glob(glob_dir, recursive=True):
        if any(map(lambda sk: sk in record, SKIP_LIST)):  # pylint: disable=cell-var-from-loop
        if any(
            map(lambda sk: sk in record, SKIP_LIST)
        ):  # pylint: disable=cell-var-from-loop
            continue
        with open(record, 'r') as rec:
        with open(record, "r") as rec:
            for re in rec.readlines():
                re = re.strip().split('|')
                audio = os.path.join(os.path.dirname(record), 'wavs', re[0] + '.wav')
                re = re.strip().split("|")
                audio = os.path.join(os.path.dirname(record), "wavs", re[0] + ".wav")
                transcript = re[2]
                samples.append((audio, transcript))

@ -129,9 +132,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
    pool.close()
    pool.join()

    with open(target_csv_template.format('train'), 'w') as train_csv_file:  # 80%
        with open(target_csv_template.format('dev'), 'w') as dev_csv_file:  # 10%
            with open(target_csv_template.format('test'), 'w') as test_csv_file:  # 10%
    with open(target_csv_template.format("train"), "w") as train_csv_file:  # 80%
        with open(target_csv_template.format("dev"), "w") as dev_csv_file:  # 10%
            with open(target_csv_template.format("test"), "w") as test_csv_file:  # 10%
                train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
                train_writer.writeheader()
                dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -151,39 +154,60 @@ def _maybe_convert_sets(target_dir, extracted_data):
                        writer = dev_writer
                    else:
                        writer = train_writer
                    writer.writerow(dict(
                        wav_filename=os.path.relpath(wav_filename, extracted_dir),
                        wav_filesize=os.path.getsize(wav_filename),
                        transcript=transcript,
                    ))
                    writer.writerow(
                        dict(
                            wav_filename=os.path.relpath(wav_filename, extracted_dir),
                            wav_filesize=os.path.getsize(wav_filename),
                            transcript=transcript,
                        )
                    )

    imported_samples = get_imported_samples(counter)
    assert counter['all'] == num_samples
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples

    print_import_report(counter, SAMPLE_RATE, MAX_SECS)


def handle_args():
    parser = get_importers_parser(description='Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/.')
    parser.add_argument(dest='target_dir')
    parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
    parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    parser.add_argument('--skiplist', type=str, default='', help='Directories / books to skip, comma separated')
    parser.add_argument('--language', required=True, type=str, help='Dataset language to use')
    parser = get_importers_parser(
        description="Importer for M-AILABS dataset. https://www.caito.de/2019/01/the-m-ailabs-speech-dataset/."
    )
    parser.add_argument(dest="target_dir")
    parser.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    parser.add_argument(
        "--skiplist",
        type=str,
        default="",
        help="Directories / books to skip, comma separated",
    )
    parser.add_argument(
        "--language", required=True, type=str, help="Dataset language to use"
    )
    return parser.parse_args()


if __name__ == "__main__":
    CLI_ARGS = handle_args()
    ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
    SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
    SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(","))
    validate_label = get_validate_label(CLI_ARGS)

    def label_filter(label):
        if CLI_ARGS.normalize:
            label = unicodedata.normalize("NFKD", label.strip()) \
                .encode("ascii", "ignore") \
            label = (
                unicodedata.normalize("NFKD", label.strip())
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        label = validate_label(label)
        if ALPHABET and label:
            try:

@ -1,30 +1,24 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser
import glob
import pandas
import os
import tarfile
import wave

import pandas

COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print('Extracting {} into {}...'.format(archive_path, target_dir))
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)


def is_file_truncated(wav_filename, wav_filesize):
    with wave.open(wav_filename, mode='rb') as fin:
    with wave.open(wav_filename, mode="rb") as fin:
        assert fin.getframerate() == 16000
        assert fin.getsampwidth() == 2
        assert fin.getnchannels() == 1
@ -37,8 +31,13 @@ def is_file_truncated(wav_filename, wav_filesize):

def preprocess_data(folder_with_archives, target_dir):
    # First extract subset archives
    for subset in ('train', 'dev', 'test'):
        extract(os.path.join(folder_with_archives, 'magicdata_{}_set.tar.gz'.format(subset)), target_dir)
    for subset in ("train", "dev", "test"):
        extract(
            os.path.join(
                folder_with_archives, "magicdata_{}_set.tar.gz".format(subset)
            ),
            target_dir,
        )

    # Folder structure is now:
    # - magicdata_{train,dev,test}.tar.gz
@ -54,58 +53,73 @@ def preprocess_data(folder_with_archives, target_dir):
    # name, one containing the speaker ID, and one containing the transcription

    def load_set(set_path):
        transcripts = pandas.read_csv(os.path.join(set_path, 'TRANS.txt'), sep='\t', index_col=0)
        glob_path = os.path.join(set_path, '*', '*.wav')
        transcripts = pandas.read_csv(
            os.path.join(set_path, "TRANS.txt"), sep="\t", index_col=0
        )
        glob_path = os.path.join(set_path, "*", "*.wav")
        set_files = []
        for wav in glob.glob(glob_path):
            try:
                wav_filename = wav
                wav_filesize = os.path.getsize(wav)
                transcript_key = os.path.basename(wav)
                transcript = transcripts.loc[transcript_key, 'Transcription']
                transcript = transcripts.loc[transcript_key, "Transcription"]

                # Some files in this dataset are truncated, the header duration
                # doesn't match the file size. This causes errors at training
                # time, so check here if things are fine before including a file
                if is_file_truncated(wav_filename, wav_filesize):
                    print('Warning: File {} is corrupted, header duration does '
                          'not match file size. Ignoring.'.format(wav_filename))
                    print(
                        "Warning: File {} is corrupted, header duration does "
                        "not match file size. Ignoring.".format(wav_filename)
                    )
                    continue

                set_files.append((wav_filename, wav_filesize, transcript))
            except KeyError:
                print('Warning: Missing transcript for WAV file {}.'.format(wav))
                print("Warning: Missing transcript for WAV file {}.".format(wav))
        return set_files

    for subset in ('train', 'dev', 'test'):
        print('Loading {} set samples...'.format(subset))
    for subset in ("train", "dev", "test"):
        print("Loading {} set samples...".format(subset))
        subset_files = load_set(os.path.join(target_dir, subset))
        df = pandas.DataFrame(data=subset_files, columns=COLUMN_NAMES)

        # Trim train set to under 10s
        if subset == 'train':
            durations = (df['wav_filesize'] - 44) / 16000 / 2
        if subset == "train":
            durations = (df["wav_filesize"] - 44) / 16000 / 2
            df = df[durations <= 10.0]
            print('Trimming {} samples > 10 seconds'.format((durations > 10.0).sum()))

            with_noise = df['transcript'].str.contains(r'\[(FIL|SPK)\]')
            df = df[~with_noise]
            print('Trimming {} samples with noise ([FIL] or [SPK])'.format(sum(with_noise)))
            print("Trimming {} samples > 10 seconds".format((durations > 10.0).sum()))

        dest_csv = os.path.join(target_dir, 'magicdata_{}.csv'.format(subset))
        print('Saving {} set into {}...'.format(subset, dest_csv))
            with_noise = df["transcript"].str.contains(r"\[(FIL|SPK)\]")
            df = df[~with_noise]
            print(
                "Trimming {} samples with noise ([FIL] or [SPK])".format(
                    sum(with_noise)
                )
            )

        dest_csv = os.path.join(target_dir, "magicdata_{}.csv".format(subset))
        print("Saving {} set into {}...".format(subset, dest_csv))
        df.to_csv(dest_csv, index=False)


def main():
    # https://openslr.org/68/
    parser = get_importers_parser(description='Import MAGICDATA corpus')
    parser.add_argument('folder_with_archives', help='Path to folder containing magicdata_{train,dev,test}.tar.gz')
    parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives')
    parser = get_importers_parser(description="Import MAGICDATA corpus")
    parser.add_argument(
        "folder_with_archives",
        help="Path to folder containing magicdata_{train,dev,test}.tar.gz",
    )
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to a folder called magicdata next to the archives",
    )
    params = parser.parse_args()

    if not params.target_dir:
        params.target_dir = os.path.join(params.folder_with_archives, 'magicdata')
        params.target_dir = os.path.join(params.folder_with_archives, "magicdata")

    preprocess_data(params.folder_with_archives, params.target_dir)

@ -1,25 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser
import glob
import json
import numpy as np
import pandas
import os
import tarfile

import numpy as np
import pandas

COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.importers import get_importers_parser

COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]


def extract(archive_path, target_dir):
    print('Extracting {} into {}...'.format(archive_path, target_dir))
    print("Extracting {} into {}...".format(archive_path, target_dir))
    with tarfile.open(archive_path) as tar:
        tar.extractall(target_dir)

@ -27,7 +21,7 @@ def extract(archive_path, target_dir):
def preprocess_data(tgz_file, target_dir):
    # First extract main archive and sub-archives
    extract(tgz_file, target_dir)
    main_folder = os.path.join(target_dir, 'primewords_md_2018_set1')
    main_folder = os.path.join(target_dir, "primewords_md_2018_set1")

    # Folder structure is now:
    # - primewords_md_2018_set1/
@ -35,14 +29,11 @@ def preprocess_data(tgz_file, target_dir):
    #   - [0-f]/[00-0f]/*.wav
    #   - set1_transcript.json

    transcripts_path = os.path.join(main_folder, 'set1_transcript.json')
    transcripts_path = os.path.join(main_folder, "set1_transcript.json")
    with open(transcripts_path) as fin:
        transcripts = json.load(fin)

    transcripts = {
        entry['file']: entry['text']
        for entry in transcripts
    }
    transcripts = {entry["file"]: entry["text"] for entry in transcripts}

    def load_set(glob_path):
        set_files = []
@ -54,13 +45,13 @@ def preprocess_data(tgz_file, target_dir):
                transcript = transcripts[transcript_key]
                set_files.append((wav_filename, wav_filesize, transcript))
            except KeyError:
                print('Warning: Missing transcript for WAV file {}.'.format(wav))
                print("Warning: Missing transcript for WAV file {}.".format(wav))
        return set_files

    # Load all files, then deterministically split into train/dev/test sets
    all_files = load_set(os.path.join(main_folder, 'audio_files', '*', '*', '*.wav'))
    all_files = load_set(os.path.join(main_folder, "audio_files", "*", "*", "*.wav"))
    df = pandas.DataFrame(data=all_files, columns=COLUMN_NAMES)
    df.sort_values(by='wav_filename', inplace=True)
    df.sort_values(by="wav_filename", inplace=True)

    indices = np.arange(0, len(df))
    np.random.seed(12345)
@ -73,29 +64,33 @@ def preprocess_data(tgz_file, target_dir):
    train_indices = indices[:-10000]

    train_files = df.iloc[train_indices]
    durations = (train_files['wav_filesize'] - 44) / 16000 / 2
    durations = (train_files["wav_filesize"] - 44) / 16000 / 2
    train_files = train_files[durations <= 15.0]
    print('Trimming {} samples > 15 seconds'.format((durations > 15.0).sum()))
    dest_csv = os.path.join(target_dir, 'primewords_train.csv')
    print('Saving train set into {}...'.format(dest_csv))
    print("Trimming {} samples > 15 seconds".format((durations > 15.0).sum()))
    dest_csv = os.path.join(target_dir, "primewords_train.csv")
    print("Saving train set into {}...".format(dest_csv))
    train_files.to_csv(dest_csv, index=False)

    dev_files = df.iloc[dev_indices]
    dest_csv = os.path.join(target_dir, 'primewords_dev.csv')
    print('Saving dev set into {}...'.format(dest_csv))
    dest_csv = os.path.join(target_dir, "primewords_dev.csv")
    print("Saving dev set into {}...".format(dest_csv))
    dev_files.to_csv(dest_csv, index=False)

    test_files = df.iloc[test_indices]
    dest_csv = os.path.join(target_dir, 'primewords_test.csv')
    print('Saving test set into {}...'.format(dest_csv))
    dest_csv = os.path.join(target_dir, "primewords_test.csv")
    print("Saving test set into {}...".format(dest_csv))
    test_files.to_csv(dest_csv, index=False)


def main():
    # https://www.openslr.org/47/
    parser = get_importers_parser(description='Import Primewords Chinese corpus set 1')
    parser.add_argument('tgz_file', help='Path to primewords_md_2018_set1.tar.gz')
    parser.add_argument('--target_dir', default='', help='Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.')
    parser = get_importers_parser(description="Import Primewords Chinese corpus set 1")
    parser.add_argument("tgz_file", help="Path to primewords_md_2018_set1.tar.gz")
    parser.add_argument(
        "--target_dir",
        default="",
        help="Target folder to extract files into and put the resulting CSVs. Defaults to same folder as the main archive.",
    )
    params = parser.parse_args()

    if not params.target_dir:

@ -1,45 +1,39 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report

import csv
import os
import re
import sox
import zipfile
import subprocess
import progressbar
import unicodedata
import tarfile

from multiprocessing import Pool
from util.downloader import SIMPLE_BAR

from os import path
import unicodedata
import zipfile
from glob import glob
from multiprocessing import Pool

from util.downloader import maybe_download
from util.text import Alphabet
from util.helpers import secs_to_hours
import progressbar
import sox

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    get_importers_parser,
    get_validate_label,
    print_import_report,
)
from deepspeech_training.util.text import Alphabet

FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000
MAX_SECS = 15

ARCHIVE_DIR_NAME = 'African_Accented_French'
ARCHIVE_NAME = 'African_Accented_French.tar.gz'
ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
ARCHIVE_DIR_NAME = "African_Accented_French"
ARCHIVE_NAME = "African_Accented_French.tar.gz"
ARCHIVE_URL = "http://www.openslr.org/resources/57/" + ARCHIVE_NAME


def _download_and_preprocess_data(target_dir):
    # Making path absolute
    target_dir = path.abspath(target_dir)
    target_dir = os.path.abspath(target_dir)
    # Conditionally download data
    archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
    # Conditionally extract data
@ -47,10 +41,11 @@ def _download_and_preprocess_data(target_dir):
    # Produce CSV files
    _maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME)


def _maybe_extract(target_dir, extracted_data, archive_path):
    # If target_dir/extracted_data does not exist, extract archive in target_dir
    extracted_path = path.join(target_dir, extracted_data)
    if not path.exists(extracted_path):
    extracted_path = os.path.join(target_dir, extracted_data)
    if not os.path.exists(extracted_path):
        print('No directory "%s" - extracting archive...' % extracted_path)
        if not os.path.isdir(extracted_path):
            os.mkdir(extracted_path)
@ -60,81 +55,89 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
    else:
        print('Found directory "%s" - not extracting it from archive.' % archive_path)


def one_sample(sample):
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
    wav_filename = sample[0]
    file_size = -1
    frames = 0
    if path.exists(wav_filename):
        file_size = path.getsize(wav_filename)
        frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
    if os.path.exists(wav_filename):
        file_size = os.path.getsize(wav_filename)
        frames = int(
            subprocess.check_output(
                ["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
            )
        )
    label = label_filter(sample[1])
    counter = get_counter()
    rows = []
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter['failed'] += 1
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter['invalid_label'] += 1
    elif int(frames/SAMPLE_RATE*1000/15/2) < len(str(label)):
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 15 / 2) < len(str(label)):
        # Excluding samples that are too short to fit the transcript
        counter['too_short'] += 1
    elif frames/SAMPLE_RATE > MAX_SECS:
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter['too_long'] += 1
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((wav_filename, file_size, label))
        counter['all'] += 1
        counter['total_time'] += frames
        counter["all"] += 1
        counter["total_time"] += frames

    return (counter, rows)


def _maybe_convert_sets(target_dir, extracted_data):
    extracted_dir = path.join(target_dir, extracted_data)
    extracted_dir = os.path.join(target_dir, extracted_data)
    # override existing CSV with normalized one
    target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
    target_csv_template = os.path.join(
        target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace(".tar.gz", "_{}.csv")
    )
    if os.path.isfile(target_csv_template):
        return

    wav_root_dir = os.path.join(extracted_dir)

    all_files = [
        'transcripts/train/yaounde/fn_text.txt',
        'transcripts/train/ca16_conv/transcripts.txt',
        'transcripts/train/ca16_read/conditioned.txt',
        'transcripts/dev/niger_west_african_fr/transcripts.txt',
        'speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv',
        'transcripts/devtest/ca16_read/conditioned.txt',
        'transcripts/test/ca16/prompts.txt',
        "transcripts/train/yaounde/fn_text.txt",
        "transcripts/train/ca16_conv/transcripts.txt",
        "transcripts/train/ca16_read/conditioned.txt",
        "transcripts/dev/niger_west_african_fr/transcripts.txt",
        "speech/dev/niger_west_african_fr/niger_wav_file_name_transcript.tsv",
        "transcripts/devtest/ca16_read/conditioned.txt",
        "transcripts/test/ca16/prompts.txt",
    ]

    transcripts = {}
    for tr in all_files:
        with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), 'r') as tr_source:
        with open(os.path.join(target_dir, ARCHIVE_DIR_NAME, tr), "r") as tr_source:
            for line in tr_source.readlines():
                line = line.strip()

                if '.tsv' in tr:
                    sep = '\t'
                if ".tsv" in tr:
                    sep = "\t"
                else:
                    sep = ' '
                    sep = " "

audio = os.path.basename(line.split(sep)[0])
|
||||
|
||||
if not ('.wav' in audio):
|
||||
if '.tdf' in audio:
|
||||
audio = audio.replace('.tdf', '.wav')
|
||||
if not (".wav" in audio):
|
||||
if ".tdf" in audio:
|
||||
audio = audio.replace(".tdf", ".wav")
|
||||
else:
|
||||
audio += '.wav'
|
||||
audio += ".wav"
|
||||
|
||||
transcript = ' '.join(line.split(sep)[1:])
|
||||
transcript = " ".join(line.split(sep)[1:])
|
||||
transcripts[audio] = transcript
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
samples = []
|
||||
glob_dir = os.path.join(wav_root_dir, '**/*.wav')
|
||||
glob_dir = os.path.join(wav_root_dir, "**/*.wav")
|
||||
for record in glob(glob_dir, recursive=True):
|
||||
record_file = os.path.basename(record)
|
||||
if record_file in transcripts:
|
||||
@ -156,9 +159,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -178,25 +181,38 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for African Accented French dataset. More information on http://www.openslr.org/57/.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--filter_alphabet', help='Exclude samples with characters not in provided alphabet')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser = get_importers_parser(
|
||||
description="Importer for African Accented French dataset. More information on http://www.openslr.org/57/."
|
||||
)
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--filter_alphabet",
|
||||
help="Exclude samples with characters not in provided alphabet",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
@ -204,9 +220,11 @@ if __name__ == "__main__":
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
label = unicodedata.normalize("NFKD", label.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
label = (
|
||||
unicodedata.normalize("NFKD", label.strip())
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
|
@ -1,44 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
|
||||
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
|
||||
# ./data/swb/swb1_LDC97S62.tgz
|
||||
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import pandas
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import codecs
|
||||
import tarfile
|
||||
import requests
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
import pandas
|
||||
import requests
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
|
||||
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
|
||||
ARCHIVE_URL = 'http://www.openslr.org/resources/5/'
|
||||
ARCHIVE_DIR_NAME = 'LDC97S62'
|
||||
LDC_DATASET = 'swb1_LDC97S62.tgz'
|
||||
ARCHIVE_NAME = "switchboard_word_alignments.tar.gz"
|
||||
ARCHIVE_URL = "http://www.openslr.org/resources/5/"
|
||||
ARCHIVE_DIR_NAME = "LDC97S62"
|
||||
LDC_DATASET = "swb1_LDC97S62.tgz"
|
||||
|
||||
|
||||
def download_file(folder, url):
|
||||
# https://stackoverflow.com/a/16696317/738515
|
||||
local_filename = url.split('/')[-1]
|
||||
local_filename = url.split("/")[-1]
|
||||
full_filename = os.path.join(folder, local_filename)
|
||||
r = requests.get(url, stream=True)
|
||||
with open(full_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
with open(full_filename, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
f.write(chunk)
|
||||
return full_filename
|
||||
|
||||
@ -46,7 +40,7 @@ def download_file(folder, url):
|
||||
def maybe_download(archive_url, target_dir, ldc_dataset):
|
||||
# If archive file does not exist, download it...
|
||||
archive_path = os.path.join(target_dir, ldc_dataset)
|
||||
ldc_path = archive_url+ldc_dataset
|
||||
ldc_path = archive_url + ldc_dataset
|
||||
if not os.path.exists(target_dir):
|
||||
print('No path "%s" - creating ...' % target_dir)
|
||||
makedirs(target_dir)
|
||||
@ -65,17 +59,23 @@ def _download_and_preprocess_data(data_dir):
|
||||
archive_path = os.path.abspath(os.path.join(data_dir, LDC_DATASET))
|
||||
|
||||
# Check swb1_LDC97S62.tgz then extract
|
||||
assert(os.path.isfile(archive_path))
|
||||
assert os.path.isfile(archive_path)
|
||||
_extract(target_dir, archive_path)
|
||||
|
||||
|
||||
# Transcripts
|
||||
transcripts_path = maybe_download(ARCHIVE_URL, target_dir, ARCHIVE_NAME)
|
||||
_extract(target_dir, transcripts_path)
|
||||
|
||||
# Check swb1_d1/2/3/4/swb_ms98_transcriptions
|
||||
expected_folders = ["swb1_d1","swb1_d2","swb1_d3","swb1_d4","swb_ms98_transcriptions"]
|
||||
assert(all([os.path.isdir(os.path.join(target_dir,e)) for e in expected_folders]))
|
||||
|
||||
expected_folders = [
|
||||
"swb1_d1",
|
||||
"swb1_d2",
|
||||
"swb1_d3",
|
||||
"swb1_d4",
|
||||
"swb_ms98_transcriptions",
|
||||
]
|
||||
assert all([os.path.isdir(os.path.join(target_dir, e)) for e in expected_folders])
|
||||
|
||||
# Conditionally convert swb sph data to wav
|
||||
_maybe_convert_wav(target_dir, "swb1_d1", "swb1_d1-wav")
|
||||
_maybe_convert_wav(target_dir, "swb1_d2", "swb1_d2-wav")
|
||||
@ -83,13 +83,21 @@ def _download_and_preprocess_data(data_dir):
|
||||
_maybe_convert_wav(target_dir, "swb1_d4", "swb1_d4-wav")
|
||||
|
||||
# Conditionally split wav data
|
||||
d1 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav")
|
||||
d2 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav")
|
||||
d3 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav")
|
||||
d4 = _maybe_split_wav_and_sentences(target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav")
|
||||
|
||||
d1 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d1-wav", "swb1_d1-split-wav"
|
||||
)
|
||||
d2 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d2-wav", "swb1_d2-split-wav"
|
||||
)
|
||||
d3 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d3-wav", "swb1_d3-split-wav"
|
||||
)
|
||||
d4 = _maybe_split_wav_and_sentences(
|
||||
target_dir, "swb_ms98_transcriptions", "swb1_d4-wav", "swb1_d4-split-wav"
|
||||
)
|
||||
|
||||
swb_files = d1.append(d2).append(d3).append(d4)
|
||||
|
||||
|
||||
train_files, dev_files, test_files = _split_sets(swb_files)
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
@ -97,7 +105,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(os.path.join(target_dir, "swb-dev.csv"), index=False)
|
||||
test_files.to_csv(os.path.join(target_dir, "swb-test.csv"), index=False)
|
||||
|
||||
|
||||
|
||||
def _extract(target_dir, archive_path):
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
@ -118,25 +126,46 @@ def _maybe_convert_wav(data_dir, original_data, converted_data):
|
||||
# Loop over sph files in source_dir and convert each to 16-bit PCM wav
|
||||
for root, dirnames, filenames in os.walk(source_dir):
|
||||
for filename in fnmatch.filter(filenames, "*.sph"):
|
||||
for channel in ['1', '2']:
|
||||
for channel in ["1", "2"]:
|
||||
sph_file = os.path.join(root, filename)
|
||||
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav"
|
||||
wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(target_dir, wav_filename)
|
||||
temp_wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + "-temp.wav"
|
||||
temp_wav_filename = (
|
||||
os.path.splitext(os.path.basename(sph_file))[0]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ "-temp.wav"
|
||||
)
|
||||
temp_wav_file = os.path.join(target_dir, temp_wav_filename)
|
||||
print("converting {} to {}".format(sph_file, temp_wav_file))
|
||||
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file])
|
||||
subprocess.check_call(
|
||||
[
|
||||
"sph2pipe",
|
||||
"-c",
|
||||
channel,
|
||||
"-p",
|
||||
"-f",
|
||||
"rif",
|
||||
sph_file,
|
||||
temp_wav_file,
|
||||
]
|
||||
)
|
||||
print("upsampling {} to {}".format(temp_wav_file, wav_file))
|
||||
audioData, frameRate = librosa.load(temp_wav_file, sr=16000, mono=True)
|
||||
soundfile.write(wav_file, audioData, frameRate, "PCM_16")
|
||||
os.remove(temp_wav_file)
|
||||
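The hunk above only rewraps the sph-to-wav conversion; the behaviour stays the same: sph2pipe extracts one channel of the two-channel SPH file as RIFF PCM, then librosa resamples it to 16 kHz and soundfile writes 16-bit PCM. A minimal standalone sketch of the same pipeline, assuming sph2pipe is on PATH; the file names are placeholders:

import os
import subprocess

import librosa
import soundfile

def sph_channel_to_wav16k(sph_file, channel, wav_file):
    temp_wav_file = wav_file + ".temp.wav"
    # Extract a single channel ("1" or "2") as RIFF PCM, as the importer does.
    subprocess.check_call(
        ["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, temp_wav_file]
    )
    # Resample to 16 kHz mono and write 16-bit PCM.
    audio, rate = librosa.load(temp_wav_file, sr=16000, mono=True)
    soundfile.write(wav_file, audio, rate, "PCM_16")
    os.remove(temp_wav_file)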
|
||||
|
||||
|
||||
def _parse_transcriptions(trans_file):
|
||||
segments = []
|
||||
with codecs.open(trans_file, "r", "utf-8") as fin:
|
||||
for line in fin:
|
||||
if line.startswith("#") or len(line) <= 1:
|
||||
continue
|
||||
|
||||
tokens = line.split()
|
||||
@ -150,15 +179,19 @@ def _parse_transcriptions(trans_file):
|
||||
# We need to do the encode-decode dance here because encode
|
||||
# returns a bytes() object on Python 3, and text_to_char_array
|
||||
# expects a string.
|
||||
transcript = unicodedata.normalize("NFKD", transcript) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
transcript = (
|
||||
unicodedata.normalize("NFKD", transcript)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
|
||||
segments.append({
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"transcript": transcript,
|
||||
})
|
||||
segments.append(
|
||||
{
|
||||
"start_time": start_time,
|
||||
"stop_time": stop_time,
|
||||
"transcript": transcript,
|
||||
}
|
||||
)
|
||||
return segments
|
||||
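The "encode-decode dance" in the comment above strips accents: NFKD decomposes each character into a base letter plus combining marks, encode("ascii", "ignore") drops the non-ASCII marks, and decode() turns the bytes back into a string. A self-contained illustration:

import unicodedata

def to_ascii(text):
    # "café" -> "cafe": the combining accent is dropped, the base letter kept.
    return (
        unicodedata.normalize("NFKD", text)
        .encode("ascii", "ignore")
        .decode("ascii", "ignore")
    )

assert to_ascii("café") == "cafe"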
|
||||
|
||||
@ -183,8 +216,16 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
|
||||
segments = _parse_transcriptions(trans_file)
|
||||
|
||||
# Open wav corresponding to transcription file
|
||||
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A']
|
||||
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav"
|
||||
channel = ("2", "1")[
|
||||
(os.path.splitext(os.path.basename(trans_file))[0])[6] == "A"
|
||||
]
|
||||
wav_filename = (
|
||||
"sw0"
|
||||
+ (os.path.splitext(os.path.basename(trans_file))[0])[2:6]
|
||||
+ "-"
|
||||
+ channel
|
||||
+ ".wav"
|
||||
)
|
||||
wav_file = os.path.join(source_dir, wav_filename)
|
||||
|
||||
print("splitting {} according to {}".format(wav_file, trans_file))
|
||||
@ -200,26 +241,39 @@ def _maybe_split_wav_and_sentences(data_dir, trans_data, original_data, converte
|
||||
# Create wav segment filename
|
||||
start_time = segment["start_time"]
|
||||
stop_time = segment["stop_time"]
|
||||
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(
|
||||
start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
os.path.splitext(os.path.basename(trans_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
if _is_wav_too_short(new_wav_filename):
|
||||
continue
|
||||
new_wav_file = os.path.join(target_dir, new_wav_filename)
|
||||
|
||||
_split_wav(origAudio, start_time, stop_time, new_wav_file)
|
||||
|
||||
new_wav_filesize = os.path.getsize(new_wav_file)
|
||||
transcript = segment["transcript"]
|
||||
files.append((os.path.abspath(new_wav_file), new_wav_filesize, transcript))
|
||||
files.append(
|
||||
(os.path.abspath(new_wav_file), new_wav_filesize, transcript)
|
||||
)
|
||||
|
||||
# Close origAudio
|
||||
origAudio.close()
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
def _is_wav_too_short(wav_filename):
|
||||
short_wav_filenames = ['sw2986A-ms98-a-trans-80.6385-83.358875.wav', 'sw2663A-ms98-a-trans-161.12025-164.213375.wav']
|
||||
short_wav_filenames = [
|
||||
"sw2986A-ms98-a-trans-80.6385-83.358875.wav",
|
||||
"sw2663A-ms98-a-trans-161.12025-164.213375.wav",
|
||||
]
|
||||
return wav_filename in short_wav_filenames
|
||||
|
||||
|
||||
@ -234,7 +288,7 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
chunkAudio.writeframes(chunkData)
|
||||
chunkAudio.close()
|
||||
|
||||
|
||||
|
||||
def _split_sets(filelist):
|
||||
# We initially split the entire set into 80% train and 20% test, then
|
||||
# split the train set into 80% train and 20% validation.
|
||||
@ -248,10 +302,24 @@ def _split_sets(filelist):
|
||||
test_beg = dev_end
|
||||
test_end = len(filelist)
|
||||
|
||||
return (filelist[train_beg:train_end], filelist[dev_beg:dev_end], filelist[test_beg:test_end])
|
||||
return (
|
||||
filelist[train_beg:train_end],
|
||||
filelist[dev_beg:dev_end],
|
||||
filelist[test_beg:test_end],
|
||||
)
|
||||
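Per the comment in _split_sets, the corpus is first cut 80/20 into train+dev and test, and the first part is cut 80/20 again, giving 64% train, 16% dev and 20% test overall. A sketch of equivalent index arithmetic; the importer's exact boundary rounding may differ:

def split_indices(n):
    # 64% / 16% / 20% of n, as (begin, end) index pairs.
    train_end = int(0.8 * 0.8 * n)
    dev_end = int(0.8 * n)
    return (0, train_end), (train_end, dev_end), (dev_end, n)

print(split_indices(1000))  # ((0, 640), (640, 800), (800, 1000))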
|
||||
|
||||
def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, stride=1, offset=0, next_index=lambda i: i + 1, limit=0):
|
||||
def _read_data_set(
|
||||
filelist,
|
||||
thread_count,
|
||||
batch_size,
|
||||
numcep,
|
||||
numcontext,
|
||||
stride=1,
|
||||
offset=0,
|
||||
next_index=lambda i: i + 1,
|
||||
limit=0,
|
||||
):
|
||||
# Optionally apply dataset size limit
|
||||
if limit > 0:
|
||||
filelist = filelist.iloc[:limit]
|
||||
@ -259,7 +327,9 @@ def _read_data_set(filelist, thread_count, batch_size, numcep, numcontext, strid
|
||||
filelist = filelist[offset::stride]
|
||||
|
||||
# Return DataSet
|
||||
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index)
|
||||
return DataSet(
|
||||
txt_files, thread_count, batch_size, numcep, numcontext, next_index=next_index
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,70 +1,76 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
"""
|
||||
Downloads and prepares (parts of) the "Spoken Wikipedia Corpora" for DeepSpeech.py
|
||||
Use "python3 import_swc.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
"""
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import re
|
||||
import csv
|
||||
import sox
|
||||
import wave
|
||||
import shutil
|
||||
import random
|
||||
import tarfile
|
||||
import argparse
|
||||
import progressbar
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from glob import glob
|
||||
from collections import Counter
|
||||
from glob import glob
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||
SWC_ARCHIVE = "SWC_{language}.tar"
|
||||
LANGUAGES = ['dutch', 'english', 'german']
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES_EXT = FIELDNAMES + ['article', 'speaker']
|
||||
LANGUAGES = ["dutch", "english", "german"]
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
FIELDNAMES_EXT = FIELDNAMES + ["article", "speaker"]
|
||||
CHANNELS = 1
|
||||
SAMPLE_RATE = 16000
|
||||
UNKNOWN = '<unknown>'
|
||||
AUDIO_PATTERN = 'audio*.ogg'
|
||||
WAV_NAME = 'audio.wav'
|
||||
ALIGNED_NAME = 'aligned.swc'
|
||||
UNKNOWN = "<unknown>"
|
||||
AUDIO_PATTERN = "audio*.ogg"
|
||||
WAV_NAME = "audio.wav"
|
||||
ALIGNED_NAME = "aligned.swc"
|
||||
|
||||
SUBSTITUTIONS = {
|
||||
'german': [
|
||||
(re.compile(r'\$'), 'dollar'),
|
||||
(re.compile(r'€'), 'euro'),
|
||||
(re.compile(r'£'), 'pfund'),
|
||||
(re.compile(r'ein tausend ([^\s]+) hundert ([^\s]+) er( |$)'), r'\1zehnhundert \2er '),
|
||||
(re.compile(r'ein tausend (acht|neun) hundert'), r'\1zehnhundert'),
|
||||
(re.compile(r'eins punkt null null null punkt null null null punkt null null null'), 'eine milliarde'),
|
||||
(re.compile(r'punkt null null null punkt null null null punkt null null null'), 'milliarden'),
|
||||
(re.compile(r'eins punkt null null null punkt null null null'), 'eine million'),
|
||||
(re.compile(r'punkt null null null punkt null null null'), 'millionen'),
|
||||
(re.compile(r'eins punkt null null null'), 'ein tausend'),
|
||||
(re.compile(r'punkt null null null'), 'tausend'),
|
||||
(re.compile(r'punkt null'), None)
|
||||
"german": [
|
||||
(re.compile(r"\$"), "dollar"),
|
||||
(re.compile(r"€"), "euro"),
|
||||
(re.compile(r"£"), "pfund"),
|
||||
(
|
||||
re.compile(r"ein tausend ([^\s]+) hundert ([^\s]+) er( |$)"),
|
||||
r"\1zehnhundert \2er ",
|
||||
),
|
||||
(re.compile(r"ein tausend (acht|neun) hundert"), r"\1zehnhundert"),
|
||||
(
|
||||
re.compile(
|
||||
r"eins punkt null null null punkt null null null punkt null null null"
|
||||
),
|
||||
"eine milliarde",
|
||||
),
|
||||
(
|
||||
re.compile(
|
||||
r"punkt null null null punkt null null null punkt null null null"
|
||||
),
|
||||
"milliarden",
|
||||
),
|
||||
(re.compile(r"eins punkt null null null punkt null null null"), "eine million"),
|
||||
(re.compile(r"punkt null null null punkt null null null"), "millionen"),
|
||||
(re.compile(r"eins punkt null null null"), "ein tausend"),
|
||||
(re.compile(r"punkt null null null"), "tausend"),
|
||||
(re.compile(r"punkt null"), None),
|
||||
]
|
||||
}
|
||||
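The German substitution table spells out currency symbols and large round numbers so labels match how they are spoken. One of the year rules in action, using only the standard re module and values copied from the table:

import re

pattern = re.compile(r"ein tausend (acht|neun) hundert")
print(pattern.sub(r"\1zehnhundert", "ein tausend neun hundert vierzig"))
# -> "neunzehnhundert vierzig"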
|
||||
DONT_NORMALIZE = {
|
||||
'german': 'ÄÖÜäöüß'
|
||||
}
|
||||
DONT_NORMALIZE = {"german": "ÄÖÜäöüß"}
|
||||
|
||||
PRE_FILTER = str.maketrans(dict.fromkeys('/()[]{}<>:'))
|
||||
PRE_FILTER = str.maketrans(dict.fromkeys("/()[]{}<>:"))
|
||||
|
||||
|
||||
class Sample:
|
||||
@ -98,11 +104,14 @@ def get_sample_size(population_size):
|
||||
margin_of_error = 0.01
|
||||
fraction_picking = 0.50
|
||||
z_score = 2.58 # Corresponds to confidence level 99%
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (margin_of_error ** 2)
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2
|
||||
)
|
||||
sample_size = 0
|
||||
for train_size in range(population_size, 0, -1):
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / \
|
||||
(margin_of_error ** 2 * train_size)
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2 * train_size
|
||||
)
|
||||
sample_size = int(numerator / denominator)
|
||||
if 2 * sample_size + train_size <= population_size:
|
||||
break
|
||||
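get_sample_size is Cochran's sample-size formula with a finite-population correction: n0 = z^2 * p * (1 - p) / e^2 with z = 2.58 (99% confidence), p = 0.5 and e = 0.01, then n = n0 / (1 + n0 / N). The loop shrinks the candidate train size until two such samples (dev and test) also fit into the population. A worked check of the constants:

z_score, margin_of_error, fraction_picking = 2.58, 0.01, 0.5
n0 = z_score ** 2 * fraction_picking * (1 - fraction_picking) / margin_of_error ** 2
print(n0)                           # ~16641 for an unbounded population
print(int(n0 / (1 + n0 / 100000)))  # 14266 once corrected for N = 100000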
@ -111,14 +120,16 @@ def get_sample_size(population_size):
|
||||
|
||||
def maybe_download_language(language):
|
||||
lang_upper = language[0].upper() + language[1:]
|
||||
return maybe_download(SWC_ARCHIVE.format(language=lang_upper),
|
||||
CLI_ARGS.base_dir,
|
||||
SWC_URL.format(language=lang_upper))
|
||||
return maybe_download(
|
||||
SWC_ARCHIVE.format(language=lang_upper),
|
||||
CLI_ARGS.base_dir,
|
||||
SWC_URL.format(language=lang_upper),
|
||||
)
|
||||
|
||||
|
||||
def maybe_extract(data_dir, extracted_data, archive):
|
||||
extracted = path.join(data_dir, extracted_data)
|
||||
if path.isdir(extracted):
|
||||
extracted = os.path.join(data_dir, extracted_data)
|
||||
if os.path.isdir(extracted):
|
||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||
else:
|
||||
print('Extracting "{}"...'.format(archive))
|
||||
@ -133,29 +144,29 @@ def maybe_extract(data_dir, extracted_data, archive):
|
||||
def ignored(node):
|
||||
if node is None:
|
||||
return False
|
||||
if node.tag == 'ignored':
|
||||
if node.tag == "ignored":
|
||||
return True
|
||||
return ignored(node.find('..'))
|
||||
return ignored(node.find(".."))
|
||||
|
||||
|
||||
def read_token(token):
|
||||
texts, start, end = [], None, None
|
||||
notes = token.findall('n')
|
||||
notes = token.findall("n")
|
||||
if len(notes) > 0:
|
||||
for note in notes:
|
||||
attributes = note.attrib
|
||||
if start is None and 'start' in attributes:
|
||||
start = int(attributes['start'])
|
||||
if 'end' in attributes:
|
||||
token_end = int(attributes['end'])
|
||||
if start is None and "start" in attributes:
|
||||
start = int(attributes["start"])
|
||||
if "end" in attributes:
|
||||
token_end = int(attributes["end"])
|
||||
if end is None or token_end > end:
|
||||
end = token_end
|
||||
if 'pronunciation' in attributes:
|
||||
t = attributes['pronunciation']
|
||||
if "pronunciation" in attributes:
|
||||
t = attributes["pronunciation"]
|
||||
texts.append(t)
|
||||
elif 'text' in token.attrib:
|
||||
texts.append(token.attrib['text'])
|
||||
return start, end, ' '.join(texts)
|
||||
elif "text" in token.attrib:
|
||||
texts.append(token.attrib["text"])
|
||||
return start, end, " ".join(texts)
|
||||
|
||||
|
||||
def in_alphabet(alphabet, c):
|
||||
@ -163,10 +174,12 @@ def in_alphabet(alphabet, c):
|
||||
|
||||
|
||||
ALPHABETS = {}
|
||||
|
||||
|
||||
def get_alphabet(language):
|
||||
if language in ALPHABETS:
|
||||
return ALPHABETS[language]
|
||||
alphabet_path = getattr(CLI_ARGS, language + '_alphabet')
|
||||
alphabet_path = getattr(CLI_ARGS, language + "_alphabet")
|
||||
alphabet = Alphabet(alphabet_path) if alphabet_path else None
|
||||
ALPHABETS[language] = alphabet
|
||||
return alphabet
|
||||
@ -176,27 +189,35 @@ def label_filter(label, language):
|
||||
label = label.translate(PRE_FILTER)
|
||||
label = validate_label(label)
|
||||
if label is None:
|
||||
return None, 'validation'
|
||||
return None, "validation"
|
||||
substitutions = SUBSTITUTIONS[language] if language in SUBSTITUTIONS else []
|
||||
for pattern, replacement in substitutions:
|
||||
if replacement is None:
|
||||
if pattern.match(label):
|
||||
return None, 'substitution rule'
|
||||
return None, "substitution rule"
|
||||
else:
|
||||
label = pattern.sub(replacement, label)
|
||||
chars = []
|
||||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ''
|
||||
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
|
||||
alphabet = get_alphabet(language)
|
||||
for c in label:
|
||||
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c):
|
||||
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
|
||||
if (
|
||||
CLI_ARGS.normalize
|
||||
and c not in dont_normalize
|
||||
and not in_alphabet(alphabet, c)
|
||||
):
|
||||
c = (
|
||||
unicodedata.normalize("NFKD", c)
|
||||
.encode("ascii", "ignore")
|
||||
.decode("ascii", "ignore")
|
||||
)
|
||||
for sc in c:
|
||||
if not in_alphabet(alphabet, sc):
|
||||
return None, 'illegal character'
|
||||
return None, "illegal character"
|
||||
chars.append(sc)
|
||||
label = ''.join(chars)
|
||||
label = "".join(chars)
|
||||
label = validate_label(label)
|
||||
return label, 'validation' if label is None else None
|
||||
return label, "validation" if label is None else None
|
||||
|
||||
|
||||
def collect_samples(base_dir, language):
|
||||
@ -207,7 +228,9 @@ def collect_samples(base_dir, language):
|
||||
samples = []
|
||||
reasons = Counter()
|
||||
|
||||
def add_sample(p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason='complete'):
|
||||
def add_sample(
|
||||
p_wav_path, p_article, p_speaker, p_start, p_end, p_text, p_reason="complete"
|
||||
):
|
||||
if p_start is not None and p_end is not None and p_text is not None:
|
||||
duration = p_end - p_start
|
||||
text, filter_reason = label_filter(p_text, language)
|
||||
@ -217,53 +240,67 @@ def collect_samples(base_dir, language):
|
||||
p_reason = filter_reason
|
||||
elif CLI_ARGS.exclude_unknown_speakers and p_speaker == UNKNOWN:
|
||||
skip = True
|
||||
p_reason = 'unknown speaker'
|
||||
p_reason = "unknown speaker"
|
||||
elif CLI_ARGS.exclude_unknown_articles and p_article == UNKNOWN:
|
||||
skip = True
|
||||
p_reason = 'unknown article'
|
||||
p_reason = "unknown article"
|
||||
elif duration > CLI_ARGS.max_duration > 0 and CLI_ARGS.ignore_too_long:
|
||||
skip = True
|
||||
p_reason = 'exceeded duration'
|
||||
p_reason = "exceeded duration"
|
||||
elif int(duration / 30) < len(text):
|
||||
skip = True
|
||||
p_reason = 'too short to decode'
|
||||
p_reason = "too short to decode"
|
||||
elif duration / len(text) < 10:
|
||||
skip = True
|
||||
p_reason = 'length duration ratio'
|
||||
p_reason = "length duration ratio"
|
||||
if skip:
|
||||
reasons[p_reason] += 1
|
||||
else:
|
||||
samples.append(Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker))
|
||||
samples.append(
|
||||
Sample(p_wav_path, p_start, p_end, text, p_article, p_speaker)
|
||||
)
|
||||
elif p_start is None or p_end is None:
|
||||
reasons['missing timestamps'] += 1
|
||||
reasons["missing timestamps"] += 1
|
||||
else:
|
||||
reasons['missing text'] += 1
|
||||
reasons["missing text"] += 1
|
||||
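Two of the filters above are ratio checks on the sample duration, which is in milliseconds (the script later divides durations by 1000 * 60 * 60 to report hours): int(duration / 30) < len(text) drops clips with less than roughly 30 ms of audio per transcript character, and duration / len(text) < 10 drops anything under 10 ms per character. With concrete numbers:

duration = 2500                 # ms, i.e. p_end - p_start
text = "hello there world"      # 17 characters
print(int(duration / 30) < len(text))  # 83 < 17 -> False, sample is kept
print(duration / len(text) < 10)       # ~147 ms per character -> False, kept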
|
||||
print('Collecting samples...')
|
||||
print("Collecting samples...")
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
for root in bar(roots):
|
||||
wav_path = path.join(root, WAV_NAME)
|
||||
wav_path = os.path.join(root, WAV_NAME)
|
||||
aligned = ET.parse(path.join(root, ALIGNED_NAME))
|
||||
article = UNKNOWN
|
||||
speaker = UNKNOWN
|
||||
for prop in aligned.iter('prop'):
|
||||
for prop in aligned.iter("prop"):
|
||||
attributes = prop.attrib
|
||||
if 'key' in attributes and 'value' in attributes:
|
||||
if attributes['key'] == 'DC.identifier':
|
||||
article = attributes['value']
|
||||
elif attributes['key'] == 'reader.name':
|
||||
speaker = attributes['value']
|
||||
for sentence in aligned.iter('s'):
|
||||
if "key" in attributes and "value" in attributes:
|
||||
if attributes["key"] == "DC.identifier":
|
||||
article = attributes["value"]
|
||||
elif attributes["key"] == "reader.name":
|
||||
speaker = attributes["value"]
|
||||
for sentence in aligned.iter("s"):
|
||||
if ignored(sentence):
|
||||
continue
|
||||
split = False
|
||||
tokens = list(map(read_token, sentence.findall('t')))
|
||||
tokens = list(map(read_token, sentence.findall("t")))
|
||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||
for token_start, token_end, token_text in tokens:
|
||||
if CLI_ARGS.exclude_numbers and any(c.isdigit() for c in token_text):
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='has numbers')
|
||||
sample_start, sample_end, token_texts, sample_texts = None, None, [], []
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="has numbers",
|
||||
)
|
||||
sample_start, sample_end, token_texts, sample_texts = (
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
continue
|
||||
if sample_start is None:
|
||||
sample_start = token_start
|
||||
@ -271,20 +308,37 @@ def collect_samples(base_dir, language):
|
||||
continue
|
||||
token_texts.append(token_text)
|
||||
if token_end is not None:
|
||||
if token_start != sample_start and token_end - sample_start > CLI_ARGS.max_duration > 0:
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='split')
|
||||
if (
|
||||
token_start != sample_start
|
||||
and token_end - sample_start > CLI_ARGS.max_duration > 0
|
||||
):
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="split",
|
||||
)
|
||||
sample_start = sample_end
|
||||
sample_texts = []
|
||||
split = True
|
||||
sample_end = token_end
|
||||
sample_texts.extend(token_texts)
|
||||
token_texts = []
|
||||
add_sample(wav_path, article, speaker, sample_start, sample_end, ' '.join(sample_texts),
|
||||
p_reason='split' if split else 'complete')
|
||||
print('Skipped samples:')
|
||||
add_sample(
|
||||
wav_path,
|
||||
article,
|
||||
speaker,
|
||||
sample_start,
|
||||
sample_end,
|
||||
" ".join(sample_texts),
|
||||
p_reason="split" if split else "complete",
|
||||
)
|
||||
print("Skipped samples:")
|
||||
for reason, n in reasons.most_common():
|
||||
print(' - {}: {}'.format(reason, n))
|
||||
print(" - {}: {}".format(reason, n))
|
||||
return samples
|
||||
|
||||
|
||||
@ -294,8 +348,8 @@ def maybe_convert_one_to_wav(entry):
|
||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||
combiner = sox.Combiner()
|
||||
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||
output_wav = path.join(root, WAV_NAME)
|
||||
if path.isfile(output_wav):
|
||||
output_wav = os.path.join(root, WAV_NAME)
|
||||
if os.path.isfile(output_wav):
|
||||
return
|
||||
files = sorted(glob(path.join(root, AUDIO_PATTERN)))
|
||||
try:
|
||||
@ -304,18 +358,18 @@ def maybe_convert_one_to_wav(entry):
|
||||
elif len(files) > 1:
|
||||
wav_files = []
|
||||
for i, file in enumerate(files):
|
||||
wav_path = path.join(root, 'audio{}.wav'.format(i))
|
||||
wav_path = os.path.join(root, "audio{}.wav".format(i))
|
||||
transformer.build(file, wav_path)
|
||||
wav_files.append(wav_path)
|
||||
combiner.set_input_format(file_type=['wav'] * len(wav_files))
|
||||
combiner.build(wav_files, output_wav, 'concatenate')
|
||||
combiner.set_input_format(file_type=["wav"] * len(wav_files))
|
||||
combiner.build(wav_files, output_wav, "concatenate")
|
||||
except sox.core.SoxError:
|
||||
return
|
||||
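maybe_convert_one_to_wav uses pysox twice: a Transformer to resample each source file and a Combiner to concatenate the parts. The same pattern in isolation; the file names are placeholders and the sox binary must be installed:

import sox

parts = ["audio0.ogg", "audio1.ogg"]      # hypothetical input files
transformer = sox.Transformer()
transformer.convert(samplerate=16000, n_channels=1)

wav_parts = []
for i, part in enumerate(parts):
    wav = "audio{}.wav".format(i)
    transformer.build(part, wav)          # resample/downmix one part
    wav_parts.append(wav)

combiner = sox.Combiner()
combiner.convert(samplerate=16000, n_channels=1)
combiner.set_input_format(file_type=["wav"] * len(wav_parts))
combiner.build(wav_parts, "audio.wav", "concatenate")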
|
||||
|
||||
def maybe_convert_to_wav(base_dir):
|
||||
roots = list(os.walk(base_dir))
|
||||
print('Converting and joining source audio files...')
|
||||
print("Converting and joining source audio files...")
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
tp = ThreadPool()
|
||||
for _ in bar(tp.imap_unordered(maybe_convert_one_to_wav, roots)):
|
||||
@ -335,53 +389,66 @@ def assign_sub_sets(samples):
|
||||
sample_set.extend(speakers.pop(0))
|
||||
train_set = sum(speakers, [])
|
||||
if len(train_set) == 0:
|
||||
print('WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data')
|
||||
print(
|
||||
"WARNING: Unable to build dev and test sets without speaker bias as there is no speaker meta data"
|
||||
)
|
||||
random.seed(42) # same source data == same output
|
||||
random.shuffle(samples)
|
||||
for index, sample in enumerate(samples):
|
||||
if index < sample_size:
|
||||
sample.sub_set = 'dev'
|
||||
sample.sub_set = "dev"
|
||||
elif index < 2 * sample_size:
|
||||
sample.sub_set = 'test'
|
||||
sample.sub_set = "test"
|
||||
else:
|
||||
sample.sub_set = 'train'
|
||||
sample.sub_set = "train"
|
||||
else:
|
||||
for sub_set, sub_set_samples in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
|
||||
for sub_set, sub_set_samples in [
|
||||
("train", train_set),
|
||||
("dev", sample_sets[0]),
|
||||
("test", sample_sets[1]),
|
||||
]:
|
||||
for sample in sub_set_samples:
|
||||
sample.sub_set = sub_set
|
||||
for sub_set, sub_set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
t = sum(map(lambda s: s.end - s.start, sub_set_samples)) / (1000 * 60 * 60)
|
||||
print('Sub-set "{}" with {} samples (duration: {:.2f} h)'
|
||||
.format(sub_set, len(sub_set_samples), t))
|
||||
print(
|
||||
'Sub-set "{}" with {} samples (duration: {:.2f} h)'.format(
|
||||
sub_set, len(sub_set_samples), t
|
||||
)
|
||||
)
|
||||
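assign_sub_sets hands out whole speakers to dev and test so that no voice appears in more than one set, falling back to a seeded random shuffle only when speaker metadata is missing. A compact sketch of the speaker-disjoint idea; the helper names are illustrative, not the importer's:

import random
from collections import defaultdict

def speaker_disjoint_split(samples, get_speaker, dev_size, test_size):
    by_speaker = defaultdict(list)
    for s in samples:
        by_speaker[get_speaker(s)].append(s)
    groups = list(by_speaker.values())
    random.seed(42)  # same source data == same output
    random.shuffle(groups)
    dev, test = [], []
    for target, bucket in ((dev_size, dev), (test_size, test)):
        # Pop whole speakers until the set reaches its target size.
        while groups and len(bucket) < target:
            bucket.extend(groups.pop())
    train = [s for g in groups for s in g]
    return train, dev, test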
|
||||
|
||||
def create_sample_dirs(language):
|
||||
print('Creating sample directories...')
|
||||
for set_name in ['train', 'dev', 'test']:
|
||||
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
||||
if not path.isdir(dir_path):
|
||||
print("Creating sample directories...")
|
||||
for set_name in ["train", "dev", "test"]:
|
||||
dir_path = os.path.join(CLI_ARGS.base_dir, language + "-" + set_name)
|
||||
if not os.path.isdir(dir_path):
|
||||
os.mkdir(dir_path)
|
||||
|
||||
|
||||
def split_audio_files(samples, language):
|
||||
print('Splitting audio files...')
|
||||
print("Splitting audio files...")
|
||||
sub_sets = Counter()
|
||||
src_wav_files = group(samples, lambda s: s.wav_path).items()
|
||||
bar = progressbar.ProgressBar(max_value=len(src_wav_files), widgets=SIMPLE_BAR)
|
||||
for wav_path, file_samples in bar(src_wav_files):
|
||||
file_samples = sorted(file_samples, key=lambda s: s.start)
|
||||
with wave.open(wav_path, 'r') as src_wav_file:
|
||||
with wave.open(wav_path, "r") as src_wav_file:
|
||||
rate = src_wav_file.getframerate()
|
||||
for sample in file_samples:
|
||||
index = sub_sets[sample.sub_set]
|
||||
sample_wav_path = path.join(CLI_ARGS.base_dir,
|
||||
language + '-' + sample.sub_set,
|
||||
'sample-{0:06d}.wav'.format(index))
|
||||
sample_wav_path = os.path.join(
|
||||
CLI_ARGS.base_dir,
|
||||
language + "-" + sample.sub_set,
|
||||
"sample-{0:06d}.wav".format(index),
|
||||
)
|
||||
sample.wav_path = sample_wav_path
|
||||
sub_sets[sample.sub_set] += 1
|
||||
src_wav_file.setpos(int(sample.start * rate / 1000.0))
|
||||
data = src_wav_file.readframes(int((sample.end - sample.start) * rate / 1000.0))
|
||||
with wave.open(sample_wav_path, 'w') as sample_wav_file:
|
||||
data = src_wav_file.readframes(
|
||||
int((sample.end - sample.start) * rate / 1000.0)
|
||||
)
|
||||
with wave.open(sample_wav_path, "w") as sample_wav_file:
|
||||
sample_wav_file.setnchannels(src_wav_file.getnchannels())
|
||||
sample_wav_file.setsampwidth(src_wav_file.getsampwidth())
|
||||
sample_wav_file.setframerate(rate)
|
||||
@ -391,22 +458,26 @@ def split_audio_files(samples, language):
|
||||
def write_csvs(samples, language):
|
||||
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
||||
base_dir = path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
|
||||
base_dir = os.path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = os.path.join(base_dir, language + "-" + sub_set + ".csv")
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
|
||||
with open(csv_path, "w") as csv_file:
|
||||
writer = csv.DictWriter(
|
||||
csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES
|
||||
)
|
||||
writer.writeheader()
|
||||
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
|
||||
bar = progressbar.ProgressBar(
|
||||
max_value=len(set_samples), widgets=SIMPLE_BAR
|
||||
)
|
||||
for sample in bar(set_samples):
|
||||
row = {
|
||||
'wav_filename': path.relpath(sample.wav_path, base_dir),
|
||||
'wav_filesize': path.getsize(sample.wav_path),
|
||||
'transcript': sample.text
|
||||
"wav_filename": os.path.relpath(sample.wav_path, base_dir),
|
||||
"wav_filesize": os.path.getsize(sample.wav_path),
|
||||
"transcript": sample.text,
|
||||
}
|
||||
if CLI_ARGS.add_meta:
|
||||
row['article'] = sample.article
|
||||
row['speaker'] = sample.speaker
|
||||
row["article"] = sample.article
|
||||
row["speaker"] = sample.speaker
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
@ -414,8 +485,8 @@ def cleanup(archive, language):
|
||||
if not CLI_ARGS.keep_archive:
|
||||
print('Removing archive "{}"...'.format(archive))
|
||||
os.remove(archive)
|
||||
language_dir = path.join(CLI_ARGS.base_dir, language)
|
||||
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
|
||||
language_dir = os.path.join(CLI_ARGS.base_dir, language)
|
||||
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
|
||||
print('Removing intermediate files in "{}"...'.format(language_dir))
|
||||
shutil.rmtree(language_dir)
|
||||
|
||||
@ -433,34 +504,75 @@ def prepare_language(language):
|
||||
|
||||
|
||||
def handle_args():
|
||||
parser = argparse.ArgumentParser(description='Import Spoken Wikipedia Corpora')
|
||||
parser.add_argument('base_dir', help='Directory containing all data')
|
||||
parser.add_argument('--language', default='all', help='One of (all|{})'.format('|'.join(LANGUAGES)))
|
||||
parser.add_argument('--exclude_numbers', type=bool, default=True,
|
||||
help='If sequences with non-transliterated numbers should be excluded')
|
||||
parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
|
||||
parser.add_argument('--ignore_too_long', type=bool, default=False,
|
||||
help='If samples exceeding max_duration should be removed')
|
||||
parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
|
||||
parser = argparse.ArgumentParser(description="Import Spoken Wikipedia Corpora")
|
||||
parser.add_argument("base_dir", help="Directory containing all data")
|
||||
parser.add_argument(
|
||||
"--language", default="all", help="One of (all|{})".format("|".join(LANGUAGES))
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_numbers",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If sequences with non-transliterated numbers should be excluded",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_duration",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum sample duration in milliseconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_too_long",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="If samples exceeding max_duration should be removed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Converts diacritic characters to their base ones",
|
||||
)
|
||||
for language in LANGUAGES:
|
||||
parser.add_argument('--{}_alphabet'.format(language),
|
||||
help='Exclude {} samples with characters not in provided alphabet file'.format(language))
|
||||
parser.add_argument('--add_meta', action='store_true', help='Adds article and speaker CSV columns')
|
||||
parser.add_argument('--exclude_unknown_speakers', action='store_true', help='Exclude unknown speakers')
|
||||
parser.add_argument('--exclude_unknown_articles', action='store_true', help='Exclude unknown articles')
|
||||
parser.add_argument('--keep_archive', type=bool, default=True,
|
||||
help='If downloaded archives should be kept')
|
||||
parser.add_argument('--keep_intermediate', type=bool, default=False,
|
||||
help='If intermediate files should be kept')
|
||||
parser.add_argument(
|
||||
"--{}_alphabet".format(language),
|
||||
help="Exclude {} samples with characters not in provided alphabet file".format(
|
||||
language
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_meta", action="store_true", help="Adds article and speaker CSV columns"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_unknown_speakers",
|
||||
action="store_true",
|
||||
help="Exclude unknown speakers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_unknown_articles",
|
||||
action="store_true",
|
||||
help="Exclude unknown articles",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_archive",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="If downloaded archives should be kept",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_intermediate",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="If intermediate files should be kept",
|
||||
)
|
||||
return parser.parse_args()
|
||||
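One caveat about the boolean flags above, present in both the old and the new formatting: argparse's type=bool does not parse booleans, because bool() on any non-empty string is True. Passing --keep_intermediate False therefore still enables the option:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--keep_intermediate", type=bool, default=False)
print(parser.parse_args(["--keep_intermediate", "False"]).keep_intermediate)  # True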
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
if CLI_ARGS.language == 'all':
|
||||
if CLI_ARGS.language == "all":
|
||||
for lang in LANGUAGES:
|
||||
prepare_language(lang)
|
||||
elif CLI_ARGS.language in LANGUAGES:
|
||||
prepare_language(CLI_ARGS.language)
|
||||
else:
|
||||
fail('Wrong language id')
|
||||
fail("Wrong language id")
|
||||
|
@ -1,24 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import pandas
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
|
||||
from glob import glob
|
||||
from os import makedirs, path, remove, rmdir
|
||||
|
||||
import pandas
|
||||
from sox import Transformer
|
||||
from util.downloader import maybe_download
|
||||
from tensorflow.python.platform import gfile
|
||||
from util.stm import parse_stm_file
|
||||
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from deepspeech_training.util.stm import parse_stm_file
|
||||
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data
|
||||
@ -41,6 +35,7 @@ def _download_and_preprocess_data(data_dir):
|
||||
dev_files.to_csv(path.join(data_dir, "ted-dev.csv"), index=False)
|
||||
test_files.to_csv(path.join(data_dir, "ted-test.csv"), index=False)
|
||||
|
||||
|
||||
def _maybe_extract(data_dir, extracted_data, archive):
|
||||
# If data_dir/extracted_data does not exist, extract archive in data_dir
|
||||
if not gfile.Exists(path.join(data_dir, extracted_data)):
|
||||
@ -48,6 +43,7 @@ def _maybe_extract(data_dir, extracted_data, archive):
|
||||
tar.extractall(data_dir)
|
||||
tar.close()
|
||||
|
||||
|
||||
def _maybe_convert_wav(data_dir, extracted_data):
|
||||
# Create extracted_data dir
|
||||
extracted_dir = path.join(data_dir, extracted_data)
|
||||
@ -61,6 +57,7 @@ def _maybe_convert_wav(data_dir, extracted_data):
|
||||
# Conditionally convert test sph to wav
|
||||
_maybe_convert_wav_dataset(extracted_dir, "test")
|
||||
|
||||
|
||||
def _maybe_convert_wav_dataset(extracted_dir, data_set):
|
||||
# Create source dir
|
||||
source_dir = path.join(extracted_dir, data_set, "sph")
|
||||
@ -84,6 +81,7 @@ def _maybe_convert_wav_dataset(extracted_dir, data_set):
|
||||
# Remove source_dir
|
||||
rmdir(source_dir)
|
||||
|
||||
|
||||
def _maybe_split_sentences(data_dir, extracted_data):
|
||||
# Create extracted_data dir
|
||||
extracted_dir = path.join(data_dir, extracted_data)
|
||||
@ -99,6 +97,7 @@ def _maybe_split_sentences(data_dir, extracted_data):
|
||||
|
||||
return train_files, dev_files, test_files
|
||||
|
||||
|
||||
def _maybe_split_dataset(extracted_dir, data_set):
|
||||
# Create stm dir
|
||||
stm_dir = path.join(extracted_dir, data_set, "stm")
|
||||
@ -116,14 +115,21 @@ def _maybe_split_dataset(extracted_dir, data_set):
|
||||
# Open wav corresponding to stm_file
|
||||
wav_filename = path.splitext(path.basename(stm_file))[0] + ".wav"
|
||||
wav_file = path.join(wav_dir, wav_filename)
|
||||
origAudio = wave.open(wav_file,'r')
|
||||
origAudio = wave.open(wav_file, "r")
|
||||
|
||||
# Loop over stm_segments and split wav_file for each segment
|
||||
for stm_segment in stm_segments:
|
||||
# Create wav segment filename
|
||||
start_time = stm_segment.start_time
|
||||
stop_time = stm_segment.stop_time
|
||||
new_wav_filename = path.splitext(path.basename(stm_file))[0] + "-" + str(start_time) + "-" + str(stop_time) + ".wav"
|
||||
new_wav_filename = (
|
||||
path.splitext(path.basename(stm_file))[0]
|
||||
+ "-"
|
||||
+ str(start_time)
|
||||
+ "-"
|
||||
+ str(stop_time)
|
||||
+ ".wav"
|
||||
)
|
||||
new_wav_file = path.join(wav_dir, new_wav_filename)
|
||||
|
||||
# If the wav segment filename does not exist create it
|
||||
@ -131,23 +137,29 @@ def _maybe_split_dataset(extracted_dir, data_set):
|
||||
_split_wav(origAudio, start_time, stop_time, new_wav_file)
|
||||
|
||||
new_wav_filesize = path.getsize(new_wav_file)
|
||||
files.append((path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript))
|
||||
files.append(
|
||||
(path.abspath(new_wav_file), new_wav_filesize, stm_segment.transcript)
|
||||
)
|
||||
|
||||
# Close origAudio
|
||||
origAudio.close()
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
return pandas.DataFrame(
|
||||
data=files, columns=["wav_filename", "wav_filesize", "transcript"]
|
||||
)
|
||||
|
||||
|
||||
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
frameRate = origAudio.getframerate()
|
||||
origAudio.setpos(int(start_time*frameRate))
|
||||
chunkData = origAudio.readframes(int((stop_time - start_time)*frameRate))
|
||||
chunkAudio = wave.open(new_wav_file,'w')
|
||||
origAudio.setpos(int(start_time * frameRate))
|
||||
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
|
||||
chunkAudio = wave.open(new_wav_file, "w")
|
||||
chunkAudio.setnchannels(origAudio.getnchannels())
|
||||
chunkAudio.setsampwidth(origAudio.getsampwidth())
|
||||
chunkAudio.setframerate(frameRate)
|
||||
chunkAudio.writeframes(chunkData)
|
||||
chunkAudio.close()
|
||||
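_split_wav cuts a segment by seeking to start_time * frameRate and reading (stop_time - start_time) * frameRate frames, then writes them with the source file's channel count, sample width and rate. The same slice written with context managers; times are in seconds and the paths hypothetical:

import wave

def slice_wav(src_path, dst_path, start_time, stop_time):
    with wave.open(src_path, "r") as src:
        rate = src.getframerate()
        src.setpos(int(start_time * rate))
        frames = src.readframes(int((stop_time - start_time) * rate))
        with wave.open(dst_path, "w") as dst:
            # Copy the source format so the segment stays bit-compatible.
            dst.setnchannels(src.getnchannels())
            dst.setsampwidth(src.getsampwidth())
            dst.setframerate(rate)
            dst.writeframes(frames)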
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
'''
|
||||
"""
|
||||
NAME : LDC TIMIT Dataset
|
||||
URL : https://catalog.ldc.upenn.edu/ldc93s1
|
||||
HOURS : 5
|
||||
@ -8,29 +8,32 @@
|
||||
AUTHORS : Garofolo, John, et al.
|
||||
TYPE : LDC Membership
|
||||
LICENCE : LDC User Agreement
|
||||
'''
|
||||
"""
|
||||
|
||||
import errno
|
||||
import fnmatch
|
||||
import os
|
||||
from os import path
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import fnmatch
|
||||
from os import path
|
||||
|
||||
import pandas as pd
|
||||
import subprocess
|
||||
|
||||
|
||||
def clean(word):
|
||||
# Lowercase (LC ALL) & strip punctuation that is not required
|
||||
new = word.lower().replace('.', '')
|
||||
new = new.replace(',', '')
|
||||
new = new.replace(';', '')
|
||||
new = new.replace('"', '')
|
||||
new = new.replace('!', '')
|
||||
new = new.replace('?', '')
|
||||
new = new.replace(':', '')
|
||||
new = new.replace('-', '')
|
||||
new = word.lower().replace(".", "")
|
||||
new = new.replace(",", "")
|
||||
new = new.replace(";", "")
|
||||
new = new.replace('"', "")
|
||||
new = new.replace("!", "")
|
||||
new = new.replace("?", "")
|
||||
new = new.replace(":", "")
|
||||
new = new.replace("-", "")
|
||||
return new
|
||||
|
||||
|
||||
def _preprocess_data(args):
|
||||
|
||||
# Assume data is downloaded from LDC - https://catalog.ldc.upenn.edu/ldc93s1
|
||||
@ -40,16 +43,24 @@ def _preprocess_data(args):
|
||||
|
||||
if ignoreSASentences:
|
||||
print("Using recommended ignore SA sentences")
|
||||
print("Ignoring SA sentences (2 x sentences which are repeated by all speakers)")
|
||||
print(
|
||||
"Ignoring SA sentences (2 x sentences which are repeated by all speakers)"
|
||||
)
|
||||
else:
|
||||
print("Using unrecommended setting to include SA sentences")
|
||||
|
||||
datapath = args
|
||||
target = path.join(datapath, "TIMIT")
|
||||
print("Checking to see if data has already been extracted in given argument: %s", target)
|
||||
print(
|
||||
"Checking to see if data has already been extracted in given argument: %s",
|
||||
target,
|
||||
)
|
||||
|
||||
if not path.isdir(target):
|
||||
print("Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ", datapath)
|
||||
print(
|
||||
"Could not find extracted data, trying to find: TIMIT-LDC93S1.tgz in: ",
|
||||
datapath,
|
||||
)
|
||||
filepath = path.join(datapath, "TIMIT-LDC93S1.tgz")
|
||||
if path.isfile(filepath):
|
||||
print("File found, extracting")
|
||||
@ -103,40 +114,58 @@ def _preprocess_data(args):
|
||||
# if ignoreSAsentences we only want those without SA in the name
|
||||
# OR
|
||||
# if not ignoreSAsentences we want all to be added
|
||||
if (ignoreSASentences and not ('SA' in os.path.basename(full_wav))) or (not ignoreSASentences):
|
||||
if 'train' in full_wav.lower():
|
||||
if (ignoreSASentences and not ("SA" in os.path.basename(full_wav))) or (
|
||||
not ignoreSASentences
|
||||
):
|
||||
if "train" in full_wav.lower():
|
||||
train_list_wavs.append(full_wav)
|
||||
train_list_trans.append(trans)
|
||||
train_list_size.append(wav_filesize)
|
||||
elif 'test' in full_wav.lower():
|
||||
elif "test" in full_wav.lower():
|
||||
test_list_wavs.append(full_wav)
|
||||
test_list_trans.append(trans)
|
||||
test_list_size.append(wav_filesize)
|
||||
else:
|
||||
raise IOError
|
||||
|
||||
a = {'wav_filename': train_list_wavs,
|
||||
'wav_filesize': train_list_size,
|
||||
'transcript': train_list_trans
|
||||
}
|
||||
a = {
|
||||
"wav_filename": train_list_wavs,
|
||||
"wav_filesize": train_list_size,
|
||||
"transcript": train_list_trans,
|
||||
}
|
||||
|
||||
c = {'wav_filename': test_list_wavs,
|
||||
'wav_filesize': test_list_size,
|
||||
'transcript': test_list_trans
|
||||
}
|
||||
c = {
|
||||
"wav_filename": test_list_wavs,
|
||||
"wav_filesize": test_list_size,
|
||||
"transcript": test_list_trans,
|
||||
}
|
||||
|
||||
all = {'wav_filename': train_list_wavs + test_list_wavs,
|
||||
'wav_filesize': train_list_size + test_list_size,
|
||||
'transcript': train_list_trans + test_list_trans
|
||||
}
|
||||
all = {
|
||||
"wav_filename": train_list_wavs + test_list_wavs,
|
||||
"wav_filesize": train_list_size + test_list_size,
|
||||
"transcript": train_list_trans + test_list_trans,
|
||||
}
|
||||
|
||||
df_all = pd.DataFrame(all, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_train = pd.DataFrame(a, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_test = pd.DataFrame(c, columns=['wav_filename', 'wav_filesize', 'transcript'], dtype=int)
|
||||
df_all = pd.DataFrame(
|
||||
all, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
df_train = pd.DataFrame(
|
||||
a, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
df_test = pd.DataFrame(
|
||||
c, columns=["wav_filename", "wav_filesize", "transcript"], dtype=int
|
||||
)
|
||||
|
||||
df_all.to_csv(
|
||||
target + "/timit_all.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
df_train.to_csv(
|
||||
target + "/timit_train.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
df_test.to_csv(
|
||||
target + "/timit_test.csv", sep=",", header=True, index=False, encoding="ascii"
|
||||
)
|
||||
|
||||
df_all.to_csv(target+"/timit_all.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
df_train.to_csv(target+"/timit_train.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
df_test.to_csv(target+"/timit_test.csv", sep=',', header=True, index=False, encoding='ascii')
|
||||
|
||||
if __name__ == "__main__":
|
||||
_preprocess_data(sys.argv[1])
|
||||
|
150
bin/import_ts.py
150
bin/import_ts.py
@ -1,52 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import unidecode
|
||||
import zipfile
|
||||
import sox
|
||||
import subprocess
|
||||
import progressbar
|
||||
|
||||
import zipfile
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
import progressbar
|
||||
import sox
|
||||
|
||||
from util.downloader import maybe_download
|
||||
import unidecode
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
|
||||
from deepspeech_training.util.importers import (
|
||||
get_counter,
|
||||
get_imported_samples,
|
||||
get_importers_parser,
|
||||
get_validate_label,
|
||||
print_import_report,
|
||||
)
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 15
|
||||
ARCHIVE_NAME = '2019-04-11_fr_FR'
|
||||
ARCHIVE_DIR_NAME = 'ts_' + ARCHIVE_NAME
|
||||
ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE_NAME + '.zip'
|
||||
ARCHIVE_NAME = "2019-04-11_fr_FR"
|
||||
ARCHIVE_DIR_NAME = "ts_" + ARCHIVE_NAME
|
||||
ARCHIVE_URL = (
|
||||
"https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/" + ARCHIVE_NAME + ".zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
|
||||
archive_path = maybe_download(
|
||||
"ts_" + ARCHIVE_NAME + ".zip", target_dir, ARCHIVE_URL
|
||||
)
|
||||
# Conditionally extract archive data
|
||||
_maybe_extract(target_dir, ARCHIVE_DIR_NAME, archive_path)
|
||||
# Conditionally convert TrainingSpeech data to DeepSpeech CSVs and wav
|
||||
_maybe_convert_sets(target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible)
|
||||
_maybe_convert_sets(
|
||||
target_dir, ARCHIVE_DIR_NAME, english_compatible=english_compatible
|
||||
)
|
||||
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -58,16 +59,20 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
orig_filename = sample['path']
|
||||
orig_filename = sample["path"]
|
||||
# Storing converted wav files next to the originals - just with a different suffix
|
||||
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
_maybe_convert_wav(orig_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = sample['text']
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(
|
||||
subprocess.check_output(
|
||||
["soxi", "-s", wav_filename], stderr=subprocess.STDOUT
|
||||
)
|
||||
)
|
||||
label = sample["text"]
|
||||
|
||||
rows = []
|
||||
|
||||
@ -75,40 +80,41 @@ def one_sample(sample):
|
||||
counter = get_counter()
|
||||
if file_size == -1:
|
||||
# Excluding samples that failed upon conversion
|
||||
counter['failed'] += 1
|
||||
counter["failed"] += 1
|
||||
elif label is None:
|
||||
# Excluding samples that failed on label validation
|
||||
counter['invalid_label'] += 1
|
||||
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)):
|
||||
counter["invalid_label"] += 1
|
||||
elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
|
||||
# Excluding samples that are too short to fit the transcript
|
||||
counter['too_short'] += 1
|
||||
elif frames/SAMPLE_RATE > MAX_SECS:
|
||||
counter["too_short"] += 1
|
||||
elif frames / SAMPLE_RATE > MAX_SECS:
|
||||
# Excluding very long samples to keep a reasonable batch-size
|
||||
counter['too_long'] += 1
|
||||
counter["too_long"] += 1
|
||||
else:
|
||||
# This one is good - keep it for the target CSV
|
||||
rows.append((wav_filename, file_size, label))
|
||||
counter['all'] += 1
|
||||
counter['total_time'] += frames
|
||||
counter["all"] += 1
|
||||
counter["total_time"] += frames
|
||||
|
||||
return (counter, rows)
|
||||
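The too_short test above converts the frame count into 10 ms steps (frames / SAMPLE_RATE * 1000 / 10) and halves it, which presumably approximates the model's 20 ms feature stride; the transcript cannot be longer than the number of output steps. Spelled out with this file's constants:

SAMPLE_RATE = 16000
frames = 48000                                # a 3-second clip
steps = frames / SAMPLE_RATE * 1000 / 10 / 2  # 150 windows of 20 ms
print(int(steps) < len("bonjour tout le monde"))  # 150 < 21 -> False, kept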
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
|
||||
target_csv_template = os.path.join(target_dir, "ts_" + ARCHIVE_NAME + "_{}.csv")
|
||||
if os.path.isfile(target_csv_template):
|
||||
return
|
||||
path_to_original_csv = os.path.join(extracted_dir, 'data.csv')
|
||||
path_to_original_csv = os.path.join(extracted_dir, "data.csv")
|
||||
with open(path_to_original_csv) as csv_f:
|
||||
data = [
|
||||
d for d in csv.DictReader(csv_f, delimiter=',')
|
||||
if float(d['duration']) <= MAX_SECS
|
||||
d
|
||||
for d in csv.DictReader(csv_f, delimiter=",")
|
||||
if float(d["duration"]) <= MAX_SECS
|
||||
]
|
||||
|
||||
for line in data:
|
||||
line['path'] = os.path.join(extracted_dir, line['path'])
|
||||
line["path"] = os.path.join(extracted_dir, line["path"])
|
||||
|
||||
num_samples = len(data)
|
||||
rows = []
|
||||
@ -125,9 +131,9 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
with open(target_csv_template.format('train'), 'w') as train_csv_file: # 80%
|
||||
with open(target_csv_template.format('dev'), 'w') as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format('test'), 'w') as test_csv_file: # 10%
|
||||
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
|
||||
with open(target_csv_template.format("dev"), "w") as dev_csv_file: # 10%
|
||||
with open(target_csv_template.format("test"), "w") as test_csv_file: # 10%
|
||||
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
|
||||
train_writer.writeheader()
|
||||
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
|
||||
@ -136,7 +142,11 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
test_writer.writeheader()
|
||||
|
||||
for i, item in enumerate(rows):
|
||||
transcript = validate_label(cleanup_transcript(item[2], english_compatible=english_compatible))
|
||||
transcript = validate_label(
|
||||
cleanup_transcript(
|
||||
item[2], english_compatible=english_compatible
|
||||
)
|
||||
)
|
||||
if not transcript:
|
||||
continue
|
||||
wav_filename = os.path.join(target_dir, extracted_data, item[0])
|
||||
@ -147,45 +157,53 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
writer = dev_writer
|
||||
else:
|
||||
writer = train_writer
|
||||
writer.writerow(dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
))
|
||||
writer.writerow(
|
||||
dict(
|
||||
wav_filename=wav_filename,
|
||||
wav_filesize=os.path.getsize(wav_filename),
|
||||
transcript=transcript,
|
||||
)
|
||||
)
|
||||
|
||||
imported_samples = get_imported_samples(counter)
|
||||
assert counter['all'] == num_samples
|
||||
assert counter["all"] == num_samples
|
||||
assert len(rows) == imported_samples
|
||||
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
|
||||
def _maybe_convert_wav(orig_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
transformer.build(orig_filename, wav_filename)
|
||||
except sox.core.SoxError as ex:
|
||||
print('SoX processing error', ex, orig_filename, wav_filename)
|
||||
print("SoX processing error", ex, orig_filename, wav_filename)
|
||||
|
||||
|
||||
PUNCTUATIONS_REG = re.compile(r"[°\-,;!?.()\[\]*…—]")
|
||||
MULTIPLE_SPACES_REG = re.compile(r'\s{2,}')
|
||||
MULTIPLE_SPACES_REG = re.compile(r"\s{2,}")
|
||||
|
||||
|
||||
def cleanup_transcript(text, english_compatible=False):
|
||||
text = text.replace('’', "'").replace('\u00A0', ' ')
|
||||
text = PUNCTUATIONS_REG.sub(' ', text)
|
||||
text = MULTIPLE_SPACES_REG.sub(' ', text)
|
||||
text = text.replace("’", "'").replace("\u00A0", " ")
|
||||
text = PUNCTUATIONS_REG.sub(" ", text)
|
||||
text = MULTIPLE_SPACES_REG.sub(" ", text)
|
||||
if english_compatible:
|
||||
text = unidecode.unidecode(text)
|
||||
return text.strip().lower()
|
||||
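Assuming cleanup_transcript as defined above is in scope, a quick check on a French sentence shows the whole pipeline: punctuation becomes spaces, runs of spaces collapse, unidecode strips the accents, and the result is lower-cased:

print(cleanup_transcript("Où est-elle ?", english_compatible=True))
# -> "ou est elle"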
|
||||
|
||||
def handle_args():
|
||||
parser = get_importers_parser(description='Importer for TrainingSpeech dataset.')
|
||||
parser.add_argument(dest='target_dir')
|
||||
parser.add_argument('--english-compatible', action='store_true', dest='english_compatible', help='Remove diacritics and other non-ascii chars.')
|
||||
parser = get_importers_parser(description="Importer for TrainingSpeech dataset.")
|
||||
parser.add_argument(dest="target_dir")
|
||||
parser.add_argument(
|
||||
"--english-compatible",
|
||||
action="store_true",
|
||||
dest="english_compatible",
|
||||
help="Remove diactrics and other non-ascii chars.",
|
||||
)
|
||||
return parser.parse_args()
|

@ -1,45 +1,40 @@
#!/usr/bin/env python
'''
"""
Downloads and prepares (parts of) the "German Distant Speech" corpus (TUDA) for DeepSpeech.py
Use "python3 import_tuda.py -h" for help
'''
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import csv
import wave
import tarfile
"""
import argparse
import progressbar
import csv
import os
import tarfile
import unicodedata
import wave
import xml.etree.cElementTree as ET

from os import path
from collections import Counter
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR

TUDA_VERSION = 'v2'
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
TUDA_URL = 'http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz'.format(TUDA_PACKAGE)
TUDA_ARCHIVE = '{}.tar.gz'.format(TUDA_PACKAGE)
import progressbar

from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import validate_label_eng as validate_label
from deepspeech_training.util.text import Alphabet

TUDA_VERSION = "v2"
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
TUDA_URL = "http://ltdata1.informatik.uni-hamburg.de/kaldi_tuda_de/{}.tar.gz".format(
    TUDA_PACKAGE
)
TUDA_ARCHIVE = "{}.tar.gz".format(TUDA_PACKAGE)

CHANNELS = 1
SAMPLE_WIDTH = 2
SAMPLE_RATE = 16000

FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]


def maybe_extract(archive):
    extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
    if path.isdir(extracted):
    extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
    if os.path.isdir(extracted):
        print('Found directory "{}" - not extracting.'.format(extracted))
    else:
        print('Extracting "{}"...'.format(archive))
@ -52,86 +47,100 @@ def maybe_extract(archive):


def check_and_prepare_sentence(sentence):
    sentence = sentence.lower().replace('co2', 'c o zwei')
    sentence = sentence.lower().replace("co2", "c o zwei")
    chars = []
    for c in sentence:
        if CLI_ARGS.normalize and c not in 'äöüß' and (ALPHABET is None or not ALPHABET.has_char(c)):
            c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore")
        if (
            CLI_ARGS.normalize
            and c not in "äöüß"
            and (ALPHABET is None or not ALPHABET.has_char(c))
        ):
            c = (
                unicodedata.normalize("NFKD", c)
                .encode("ascii", "ignore")
                .decode("ascii", "ignore")
            )
        for sc in c:
            if ALPHABET is not None and not ALPHABET.has_char(sc):
                return None
            chars.append(sc)
    return validate_label(''.join(chars))
    return validate_label("".join(chars))


def check_wav_file(wav_path, sentence):  # pylint: disable=too-many-return-statements
    try:
        with wave.open(wav_path, 'r') as src_wav_file:
        with wave.open(wav_path, "r") as src_wav_file:
            rate = src_wav_file.getframerate()
            channels = src_wav_file.getnchannels()
            sample_width = src_wav_file.getsampwidth()
            milliseconds = int(src_wav_file.getnframes() * 1000 / rate)
            if rate != SAMPLE_RATE:
                return False, 'wrong sample rate'
                return False, "wrong sample rate"
            if channels != CHANNELS:
                return False, 'wrong number of channels'
                return False, "wrong number of channels"
            if sample_width != SAMPLE_WIDTH:
                return False, 'wrong sample width'
                return False, "wrong sample width"
            if milliseconds / len(sentence) < 30:
                return False, 'too short'
                return False, "too short"
            if milliseconds > CLI_ARGS.max_duration > 0:
                return False, 'too long'
                return False, "too long"
    except wave.Error:
        return False, 'invalid wav file'
        return False, "invalid wav file"
    except EOFError:
        return False, 'premature EOF'
    return True, 'OK'
        return False, "premature EOF"
    return True, "OK"


def write_csvs(extracted):
    sample_counter = 0
    reasons = Counter()
    for sub_set in ['train', 'dev', 'test']:
        set_path = path.join(extracted, sub_set)
    for sub_set in ["train", "dev", "test"]:
        set_path = os.path.join(extracted, sub_set)
        set_files = os.listdir(set_path)
        recordings = {}
        for file in set_files:
            if file.endswith('.xml'):
            if file.endswith(".xml"):
                recordings[file[:-4]] = []
        for file in set_files:
            if file.endswith('.wav') and '_' in file:
                prefix = file.split('_')[0]
            if file.endswith(".wav") and "_" in file:
                prefix = file.split("_")[0]
                if prefix in recordings:
                    recordings[prefix].append(file)
        recordings = recordings.items()
        csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
        csv_path = os.path.join(
            CLI_ARGS.base_dir, "tuda-{}-{}.csv".format(TUDA_VERSION, sub_set)
        )
        print('Writing "{}"...'.format(csv_path))
        with open(csv_path, 'w') as csv_file:
        with open(csv_path, "w") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
            writer.writeheader()
            set_dir = path.join(extracted, sub_set)
            set_dir = os.path.join(extracted, sub_set)
            bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
            for prefix, wav_names in bar(recordings):
                xml_path = path.join(set_dir, prefix + '.xml')
                xml_path = os.path.join(set_dir, prefix + ".xml")
                meta = ET.parse(xml_path).getroot()
                sentence = list(meta.iter('cleaned_sentence'))[0].text
                sentence = list(meta.iter("cleaned_sentence"))[0].text
                sentence = check_and_prepare_sentence(sentence)
                if sentence is None:
                    continue
                for wav_name in wav_names:
                    sample_counter += 1
                    wav_path = path.join(set_path, wav_name)
                    wav_path = os.path.join(set_path, wav_name)
                    keep, reason = check_wav_file(wav_path, sentence)
                    if keep:
                        writer.writerow({
                            'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
                            'wav_filesize': path.getsize(wav_path),
                            'transcript': sentence.lower()
                        })
                        writer.writerow(
                            {
                                "wav_filename": os.path.relpath(
                                    wav_path, CLI_ARGS.base_dir
                                ),
                                "wav_filesize": os.path.getsize(wav_path),
                                "transcript": sentence.lower(),
                            }
                        )
                    else:
                        reasons[reason] += 1
    if len(reasons.keys()) > 0:
        print('Excluded samples:')
        print("Excluded samples:")
        for reason, n in reasons.most_common():
            print(' - "{}": {} ({:.2f}%)'.format(reason, n, n * 100 / sample_counter))

@ -150,13 +159,29 @@ def download_and_prepare():


def handle_args():
    parser = argparse.ArgumentParser(description='Import German Distant Speech (TUDA)')
    parser.add_argument('base_dir', help='Directory containing all data')
    parser.add_argument('--max_duration', type=int, default=10000, help='Maximum sample duration in milliseconds')
    parser.add_argument('--normalize', action='store_true', help='Converts diacritic characters to their base ones')
    parser.add_argument('--alphabet', help='Exclude samples with characters not in provided alphabet file')
    parser.add_argument('--keep_archive', type=bool, default=True,
                        help='If downloaded archives should be kept')
    parser = argparse.ArgumentParser(description="Import German Distant Speech (TUDA)")
    parser.add_argument("base_dir", help="Directory containing all data")
    parser.add_argument(
        "--max_duration",
        type=int,
        default=10000,
        help="Maximum sample duration in milliseconds",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    parser.add_argument(
        "--alphabet",
        help="Exclude samples with characters not in provided alphabet file",
    )
    parser.add_argument(
        "--keep_archive",
        type=bool,
        default=True,
        help="If downloaded archives should be kept",
    )
    return parser.parse_args()
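
As an aside, a quick illustration of the NFKD trick used in check_and_prepare_sentence above; this is plain standard-library behavior, not repo code:

import unicodedata

# NFKD decomposes "é" into "e" plus a combining accent; encoding to ASCII with
# errors ignored then drops the combining mark, so diacritics fall back to
# their base characters ("äöüß" are exempted above to preserve German).
folded = unicodedata.normalize("NFKD", "éàç").encode("ascii", "ignore").decode("ascii", "ignore")
print(folded)  # eac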

@ -1,29 +1,22 @@
#!/usr/bin/env python

# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
# Licensed under Open Data Commons Attribution License (ODC-By) v1.0.
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html

from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import random
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))

from util.importers import get_counter, get_imported_samples, print_import_report

import re
from multiprocessing import Pool
from zipfile import ZipFile

import librosa
import progressbar

from os import path
from multiprocessing import Pool
from util.downloader import maybe_download, SIMPLE_BAR
from zipfile import ZipFile
from deepspeech_training.util.downloader import SIMPLE_BAR, maybe_download
from deepspeech_training.util.importers import (
    get_counter,
    get_imported_samples,
    print_import_report,
)

SAMPLE_RATE = 16000
MAX_SECS = 10
@ -37,7 +30,7 @@ ARCHIVE_URL = (

def _download_and_preprocess_data(target_dir):
    # Making path absolute
    target_dir = path.abspath(target_dir)
    target_dir = os.path.abspath(target_dir)
    # Conditionally download data
    archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
    # Conditionally extract common voice data
@ -48,8 +41,8 @@ def _download_and_preprocess_data(target_dir):

def _maybe_extract(target_dir, extracted_data, archive_path):
    # If target_dir/extracted_data does not exist, extract archive in target_dir
    extracted_path = path.join(target_dir, extracted_data)
    if not path.exists(extracted_path):
    extracted_path = os.path.join(target_dir, extracted_data)
    if not os.path.exists(extracted_path):
        print(f"No directory {extracted_path} - extracting archive...")
        with ZipFile(archive_path, "r") as zipobj:
            # Extract all the contents of zip file in current directory
@ -59,15 +52,17 @@ def _maybe_extract(target_dir, extracted_data, archive_path):


def _maybe_convert_sets(target_dir, extracted_data):
    extracted_dir = path.join(target_dir, extracted_data, "wav48")
    txt_dir = path.join(target_dir, extracted_data, "txt")
    extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
    txt_dir = os.path.join(target_dir, extracted_data, "txt")

    directory = os.path.expanduser(extracted_dir)
    srtd = len(sorted(os.listdir(directory)))
    all_samples = []

    for target in sorted(os.listdir(directory)):
        all_samples += _maybe_prepare_set(path.join(extracted_dir, os.path.split(target)[-1]))
        all_samples += _maybe_prepare_set(
            path.join(extracted_dir, os.path.split(target)[-1])
        )

    num_samples = len(all_samples)
    print(f"Converting wav files to {SAMPLE_RATE}hz...")
@ -81,6 +76,7 @@ def _maybe_convert_sets(target_dir, extracted_data):

    _write_csv(extracted_dir, txt_dir, target_dir)


def one_sample(sample):
    if is_audio_file(sample):
        y, sr = librosa.load(sample, sr=16000)
@ -103,6 +99,7 @@ def _maybe_prepare_set(target_csv):
    samples = new_samples
    return samples


def _write_csv(extracted_dir, txt_dir, target_dir):
    print(f"Writing CSV file")
    dset_abs_path = extracted_dir
@ -197,7 +194,9 @@ AUDIO_EXTENSIONS = [".wav", "WAV"]


def is_audio_file(filepath):
    return any(os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS)
    return any(
        os.path.basename(filepath).endswith(extension) for extension in AUDIO_EXTENSIONS
    )


if __name__ == "__main__":
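
As context for one_sample above: librosa.load resamples on read, so the conversion step boils down to a load-then-write round trip. A sketch under the assumption that soundfile (a librosa companion package) is available; the file names are invented:

import librosa
import soundfile as sf  # assumed available alongside librosa

# Hypothetical VCTK clip; librosa resamples to 16 kHz mono float32 on load.
y, sr = librosa.load("p225_001.wav", sr=16000)
sf.write("p225_001_16k.wav", y, sr, subtype="PCM_16")  # 16-bit PCM output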

@ -1,24 +1,19 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

import codecs
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import tarfile
import pandas
import re
import unicodedata
import tarfile
import threading
from multiprocessing.pool import ThreadPool

from six.moves import urllib
import unicodedata
import urllib
from glob import glob
from multiprocessing.pool import ThreadPool
from os import makedirs, path

import pandas
from bs4 import BeautifulSoup
from tensorflow.python.platform import gfile
from util.downloader import maybe_download
from deepspeech_training.util.downloader import maybe_download

"""The number of jobs to run in parallel"""
NUM_PARALLEL = 8
@ -26,8 +21,10 @@ NUM_PARALLEL = 8
"""Lambda function returns the filename of a path"""
filename_of = lambda x: path.split(x)[1]


class AtomicCounter(object):
    """A class that atomically increments a counter"""

    def __init__(self, start_count=0):
        """Initialize the counter
        :param start_count: the number to start counting at
@ -50,6 +47,7 @@ class AtomicCounter(object):
        """Returns the current value of the counter (not atomic)"""
        return self.__count


def _parallel_downloader(voxforge_url, archive_dir, total, counter):
    """Generate a function to download a file based on given parameters
    This works by currying the above given arguments into a closure
@ -61,6 +59,7 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
    :param counter: an atomic counter to keep track of # of downloaded files
    :return: a function that actually downloads a file given these params
    """

    def download(d):
        """Binds voxforge_url, archive_dir, total, and counter into this scope
        Downloads the given file
@ -68,12 +67,14 @@ def _parallel_downloader(voxforge_url, archive_dir, total, counter):
        of the file to download and file is the name of the file to download
        """
        (i, file) = d
        download_url = voxforge_url + '/' + file
        download_url = voxforge_url + "/" + file
        c = counter.increment()
        print('Downloading file {} ({}/{})...'.format(i+1, c, total))
        print("Downloading file {} ({}/{})...".format(i + 1, c, total))
        maybe_download(filename_of(download_url), archive_dir, download_url)

    return download


def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter):
    """Generate a function to extract a tar file based on given parameters
    This works by currying the above given arguments into a closure
@ -86,6 +87,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
    :param counter: an atomic counter to keep track of # of extracted files
    :return: a function that actually extracts a tar file given these params
    """

    def extract(d):
        """Binds data_dir, number_of_test, number_of_dev, total, and counter into this scope
        Extracts the given file
@ -95,58 +97,74 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
        (i, archive) = d
        if i < number_of_test:
            dataset_dir = path.join(data_dir, "test")
        elif i<number_of_test+number_of_dev:
        elif i < number_of_test + number_of_dev:
            dataset_dir = path.join(data_dir, "dev")
        else:
            dataset_dir = path.join(data_dir, "train")
        if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
        if not gfile.Exists(
            os.path.join(dataset_dir, ".".join(filename_of(archive).split(".")[:-1]))
        ):
            c = counter.increment()
            print('Extracting file {} ({}/{})...'.format(i+1, c, total))
            print("Extracting file {} ({}/{})...".format(i + 1, c, total))
            tar = tarfile.open(archive)
            tar.extractall(dataset_dir)
            tar.close()

    return extract


def _download_and_preprocess_data(data_dir):
    # Conditionally download data to data_dir
    if not path.isdir(data_dir):
        makedirs(data_dir)

    archive_dir = data_dir+"/archive"
    archive_dir = data_dir + "/archive"
    if not path.isdir(archive_dir):
        makedirs(archive_dir)

    print("Downloading Voxforge data set into {} if not already present...".format(archive_dir))
    print(
        "Downloading Voxforge data set into {} if not already present...".format(
            archive_dir
        )
    )

    voxforge_url = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit'
    voxforge_url = "http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/Audio/Main/16kHz_16bit"
    html_page = urllib.request.urlopen(voxforge_url)
    soup = BeautifulSoup(html_page, 'html.parser')
    soup = BeautifulSoup(html_page, "html.parser")

    # list all links
    refs = [l['href'] for l in soup.find_all('a') if ".tgz" in l['href']]
    refs = [l["href"] for l in soup.find_all("a") if ".tgz" in l["href"]]

    # download files in parallel
    print('{} files to download'.format(len(refs)))
    downloader = _parallel_downloader(voxforge_url, archive_dir, len(refs), AtomicCounter())
    print("{} files to download".format(len(refs)))
    downloader = _parallel_downloader(
        voxforge_url, archive_dir, len(refs), AtomicCounter()
    )
    p = ThreadPool(NUM_PARALLEL)
    p.map(downloader, enumerate(refs))

    # Conditionally extract data to dataset_dir
    if not path.isdir(path.join(data_dir,"test")):
        makedirs(path.join(data_dir,"test"))
    if not path.isdir(path.join(data_dir,"dev")):
        makedirs(path.join(data_dir,"dev"))
    if not path.isdir(path.join(data_dir,"train")):
        makedirs(path.join(data_dir,"train"))
    if not path.isdir(os.path.join(data_dir, "test")):
        makedirs(os.path.join(data_dir, "test"))
    if not path.isdir(os.path.join(data_dir, "dev")):
        makedirs(os.path.join(data_dir, "dev"))
    if not path.isdir(os.path.join(data_dir, "train")):
        makedirs(os.path.join(data_dir, "train"))

    tarfiles = glob(path.join(archive_dir, "*.tgz"))
    tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
    number_of_files = len(tarfiles)
    number_of_test = number_of_files//100
    number_of_dev = number_of_files//100
    number_of_test = number_of_files // 100
    number_of_dev = number_of_files // 100

    # extract tars in parallel
    print("Extracting Voxforge data set into {} if not already present...".format(data_dir))
    extracter = _parallel_extracter(data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter())
    print(
        "Extracting Voxforge data set into {} if not already present...".format(
            data_dir
        )
    )
    extracter = _parallel_extracter(
        data_dir, number_of_test, number_of_dev, len(tarfiles), AtomicCounter()
    )
    p.map(extracter, enumerate(tarfiles))

    # Generate data set
@ -156,42 +174,50 @@ def _download_and_preprocess_data(data_dir):
    train_files = _generate_dataset(data_dir, "train")

    # Write sets to disk as CSV files
    train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False)
    dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False)
    test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False)
    train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
    dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
    test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)


def _generate_dataset(data_dir, data_set):
    extracted_dir = path.join(data_dir, data_set)
    files = []
    for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")):
        if path.isdir(path.join(promts_file[:-11],"wav")):
            with codecs.open(promts_file, 'r', 'utf-8') as f:
    for promts_file in glob(os.path.join(extracted_dir + "/*/etc/", "PROMPTS")):
        if path.isdir(os.path.join(promts_file[:-11], "wav")):
            with codecs.open(promts_file, "r", "utf-8") as f:
                for line in f:
                    id = line.split(' ')[0].split('/')[-1]
                    sentence = ' '.join(line.split(' ')[1:])
                    sentence = re.sub("[^a-z']"," ",sentence.strip().lower())
                    id = line.split(" ")[0].split("/")[-1]
                    sentence = " ".join(line.split(" ")[1:])
                    sentence = re.sub("[^a-z']", " ", sentence.strip().lower())
                    transcript = ""
                    for token in sentence.split(" "):
                        word = token.strip()
                        if word!="" and word!=" ":
                        if word != "" and word != " ":
                            transcript += word + " "
                    transcript = unicodedata.normalize("NFKD", transcript.strip()) \
                        .encode("ascii", "ignore") \
                        .decode("ascii", "ignore")
                    wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav")
                    transcript = (
                        unicodedata.normalize("NFKD", transcript.strip())
                        .encode("ascii", "ignore")
                        .decode("ascii", "ignore")
                    )
                    wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
                    if gfile.Exists(wav_file):
                        wav_filesize = path.getsize(wav_file)
                        # remove audios that are shorter than 0.5s and longer than 20s.
                        # remove audios that are too short for transcript.
                        if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \
                            wav_filesize/len(transcript)>1400:
                            files.append((path.abspath(wav_file), wav_filesize, transcript))
                        if (
                            (wav_filesize / 32000) > 0.5
                            and (wav_filesize / 32000) < 20
                            and transcript != ""
                            and wav_filesize / len(transcript) > 1400
                        ):
                            files.append(
                                (os.path.abspath(wav_file), wav_filesize, transcript)
                            )

    return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
    return pandas.DataFrame(
        data=files, columns=["wav_filename", "wav_filesize", "transcript"]
    )

if __name__=="__main__":

if __name__ == "__main__":
    _download_and_preprocess_data(sys.argv[1])
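
The download and extract helpers above share one pattern: fixed arguments are curried into a closure, and a ThreadPool feeds it (index, item) pairs while a counter tracks progress. A stripped-down, runnable sketch of that pattern; the counter here is a lock-based stand-in, since AtomicCounter's increment() body is elided in this hunk:

import threading
from multiprocessing.pool import ThreadPool

class Counter:
    # Lock-based stand-in for the AtomicCounter used above.
    def __init__(self):
        self._lock = threading.Lock()
        self._count = 0
    def increment(self):
        with self._lock:
            self._count += 1
            return self._count

def make_worker(total, counter):
    # Same currying shape as _parallel_downloader/_parallel_extracter:
    # the pool only ever passes (index, item) tuples to the closure.
    def work(d):
        i, item = d
        print("Processing file {} ({}/{})...".format(i + 1, counter.increment(), total))
    return work

items = ["a.tgz", "b.tgz", "c.tgz"]  # hypothetical archive names
ThreadPool(2).map(make_worker(len(items), Counter()), enumerate(items))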

@ -1,15 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow.compat.v1 as tfv1
import sys

import tensorflow.compat.v1 as tfv1


def main():
    with tfv1.gfile.FastGFile(sys.argv[1], 'rb') as fin:
    with tfv1.gfile.FastGFile(sys.argv[1], "rb") as fin:
        graph_def = tfv1.GraphDef()
        graph_def.ParseFromString(fin.read())

        print('\n'.join(sorted(set(n.op for n in graph_def.node))))
        print("\n".join(sorted(set(n.op for n in graph_def.node))))

if __name__ == '__main__':

if __name__ == "__main__":
    main()

47 bin/play.py
@ -3,19 +3,13 @@
Tool for playing samples from Sample Databases (SDB files) and DeepSpeech CSV files
Use "python3 play.py -h" for help
"""
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import random
import argparse
import random
import sys

from util.sample_collections import samples_from_file, LabeledSample
from util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
from deepspeech_training.util.sample_collections import LabeledSample, samples_from_file


def play_sample(samples, index):
@ -24,7 +18,7 @@ def play_sample(samples, index):
    if CLI_ARGS.random:
        index = random.randint(0, len(samples) - 1)
    elif index >= len(samples):
        print('No sample with index {}'.format(CLI_ARGS.start))
        print("No sample with index {}".format(CLI_ARGS.start))
        sys.exit(1)
    sample = samples[index]
    print('Sample "{}"'.format(sample.sample_id))
@ -50,13 +44,28 @@ def play_collection():


def handle_args():
    parser = argparse.ArgumentParser(description='Tool for playing samples from Sample Databases (SDB files) '
                                                 'and DeepSpeech CSV files')
    parser.add_argument('collection', help='Sample DB or CSV file to play samples from')
    parser.add_argument('--start', type=int, default=0,
                        help='Sample index to start at (negative numbers are relative to the end of the collection)')
    parser.add_argument('--number', type=int, default=-1, help='Number of samples to play (-1 for endless)')
    parser.add_argument('--random', action='store_true', help='If samples should be played in random order')
    parser = argparse.ArgumentParser(
        description="Tool for playing samples from Sample Databases (SDB files) "
        "and DeepSpeech CSV files"
    )
    parser.add_argument("collection", help="Sample DB or CSV file to play samples from")
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Sample index to start at (negative numbers are relative to the end of the collection)",
    )
    parser.add_argument(
        "--number",
        type=int,
        default=-1,
        help="Number of samples to play (-1 for endless)",
    )
    parser.add_argument(
        "--random",
        action="store_true",
        help="If samples should be played in random order",
    )
    return parser.parse_args()


@ -70,5 +79,5 @@ if __name__ == "__main__":
    try:
        play_collection()
    except KeyboardInterrupt:
        print(' Stopped')
        print(" Stopped")
        sys.exit(0)

@ -31,7 +31,7 @@ for LOAD in 'init' 'last' 'auto'; do
  echo "########################################################"
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --alphabet_config_path "./data/alphabet.txt" \
    --load "$LOAD" \
    --load_train "$LOAD" \
    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
@ -45,60 +45,7 @@ for LOAD in 'init' 'last' 'auto'; do
  echo "##############################################################################"
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --alphabet_config_path "./data/alphabet.txt" \
    --load "$LOAD" \
    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
    --save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
    --scorer_path '' \
    --n_hidden 100 \
    --epochs 10

  echo "#################################################################################"
  echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
  echo "#################################################################################"
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --drop_source_layers 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load 'last' \
    --train_files "${ru_csv}" --train_batch_size 1 \
    --dev_files "${ru_csv}" --dev_batch_size 1 \
    --test_files "${ru_csv}" --test_batch_size 1 \
    --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
    --scorer_path '' \
    --n_hidden 100 \
    --epochs 10
done

echo "#######################################################"
echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
echo "##### while iterating over loading logic #####"
echo "#######################################################"

for LOAD in 'init' 'last' 'auto'; do
  echo "########################################################"
  echo "#### Train ENGLISH model with just --checkpoint_dir ####"
  echo "########################################################"
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --alphabet_config_path "./data/alphabet.txt" \
    --load "$LOAD" \
    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
    --checkpoint_dir '/tmp/ckpt/transfer/eng' \
    --scorer_path '' \
    --n_hidden 100 \
    --epochs 10


  echo "##############################################################################"
  echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
  echo "##############################################################################"
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --alphabet_config_path "./data/alphabet.txt" \
    --load "$LOAD" \
    --load_train "$LOAD" \
    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
@ -114,13 +61,20 @@ for LOAD in 'init' 'last' 'auto'; do
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --drop_source_layers 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load 'last' \
    --load_train 'last' \
    --train_files "${ru_csv}" --train_batch_size 1 \
    --dev_files "${ru_csv}" --dev_batch_size 1 \
    --test_files "${ru_csv}" --test_batch_size 1 \
    --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
    --scorer_path '' \
    --n_hidden 100 \
    --epochs 10

  # Test transfer learning checkpoint
  python -u evaluate.py --noshow_progressbar \
    --test_files "${ru_csv}" --test_batch_size 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load_checkpoint_dir '/tmp/ckpt/transfer/ru' \
    --scorer_path '' \
    --n_hidden 100
done

@ -1,17 +1,11 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))

import argparse
import shutil
import sys

from util.text import Alphabet, UTF8Alphabet
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet


@ -48,6 +42,9 @@ def create_bundle(
    if use_utf8:
        serialized_alphabet = UTF8Alphabet().serialize()
    else:
        if not alphabet_path:
            print("No --alphabet path specified, can't continue.")
            sys.exit(1)
        serialized_alphabet = Alphabet(alphabet_path).serialize()

    alphabet = NativeAlphabet()

@ -1,13 +1,17 @@
C API Usage example
===================

Examples are from `native_client/client.cc`.

Creating a model instance and loading model
-------------------------------------------

.. literalinclude:: ../native_client/client.cc
   :language: c
   :linenos:
   :lines: 370-375,386-390
   :lineno-match:
   :start-after: sphinx-doc: c_ref_model_start
   :end-before: sphinx-doc: c_ref_model_stop

Performing inference
--------------------
@ -15,7 +19,9 @@ Performing inference
.. literalinclude:: ../native_client/client.cc
   :language: c
   :linenos:
   :lines: 59-94
   :lineno-match:
   :start-after: sphinx-doc: c_ref_inference_start
   :end-before: sphinx-doc: c_ref_inference_stop

Full source code
----------------

@ -16,9 +16,9 @@ DeepSpeech Class
   :members:

DeepSpeechStream Class
----------------
----------------------

.. doxygenclass:: DeepSpeechClient::DeepSpeechStream
.. doxygenclass:: DeepSpeechClient::Models::DeepSpeechStream
   :project: deepspeech-dotnet
   :members:

29 doc/DotNet-Examples.rst (new file)
@ -0,0 +1,29 @@
.Net API Usage example
======================

Examples are from `native_client/dotnet/DeepSpeechConsole/Program.cs`.

Creating a model instance and loading model
-------------------------------------------

.. literalinclude:: ../native_client/dotnet/DeepSpeechConsole/Program.cs
   :language: csharp
   :linenos:
   :lineno-match:
   :start-after: sphinx-doc: csharp_ref_model_start
   :end-before: sphinx-doc: csharp_ref_model_stop

Performing inference
--------------------

.. literalinclude:: ../native_client/dotnet/DeepSpeechConsole/Program.cs
   :language: csharp
   :linenos:
   :lineno-match:
   :start-after: sphinx-doc: csharp_ref_inference_start
   :end-before: sphinx-doc: csharp_ref_inference_stop

Full source code
----------------

See :download:`Full source code<../native_client/dotnet/DeepSpeechConsole/Program.cs>`.

@ -1,13 +1,17 @@
Java API Usage example
======================

Examples are from `native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java`.

Creating a model instance and loading model
-------------------------------------------

.. literalinclude:: ../native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
   :language: java
   :linenos:
   :lines: 52
   :lineno-match:
   :start-after: sphinx-doc: java_ref_model_start
   :end-before: sphinx-doc: java_ref_model_stop

Performing inference
--------------------
@ -15,7 +19,9 @@ Performing inference
.. literalinclude:: ../native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java
   :language: java
   :linenos:
   :lines: 101
   :lineno-match:
   :start-after: sphinx-doc: java_ref_inference_start
   :end-before: sphinx-doc: java_ref_inference_stop

Full source code
----------------

@ -1,6 +1,8 @@
JavaScript (NodeJS / ElectronJS)
================================

Support for TypeScript is :download:`provided in index.d.ts<../native_client/javascript/index.d.ts>`

Model
-----

@ -1,23 +1,29 @@
JavaScript API Usage example
=============================

Examples are from `native_client/javascript/client.ts`.

Creating a model instance and loading model
-------------------------------------------

.. literalinclude:: ../native_client/javascript/client.js
.. literalinclude:: ../native_client/javascript/client.ts
   :language: javascript
   :linenos:
   :lines: 56,69
   :lineno-match:
   :start-after: sphinx-doc: js_ref_model_start
   :end-before: sphinx-doc: js_ref_model_stop

Performing inference
--------------------

.. literalinclude:: ../native_client/javascript/client.js
.. literalinclude:: ../native_client/javascript/client.ts
   :language: javascript
   :linenos:
   :lines: 122
   :lineno-match:
   :start-after: sphinx-doc: js_ref_inference_start
   :end-before: sphinx-doc: js_ref_inference_stop

Full source code
----------------

See :download:`Full source code<../native_client/javascript/client.js>`.
See :download:`Full source code<../native_client/javascript/client.ts>`.

@ -1,13 +1,17 @@
Python API Usage example
========================

Examples are from `native_client/python/client.py`.

Creating a model instance and loading model
-------------------------------------------

.. literalinclude:: ../native_client/python/client.py
   :language: python
   :linenos:
   :lines: 111,123
   :lineno-match:
   :start-after: sphinx-doc: python_ref_model_start
   :end-before: sphinx-doc: python_ref_model_stop

Performing inference
--------------------
@ -15,7 +19,9 @@ Performing inference
.. literalinclude:: ../native_client/python/client.py
   :language: python
   :linenos:
   :lines: 143-148
   :lineno-match:
   :start-after: sphinx-doc: python_ref_inference_start
   :end-before: sphinx-doc: python_ref_inference_stop

Full source code
----------------

@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt

.. code-block::

   $ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/
   $ python3 -m venv $HOME/tmp/deepspeech-train-venv/

Once this command completes successfully, the environment will be ready to be activated.

@ -38,15 +38,16 @@ Each time you need to work with DeepSpeech, you have to *activate* this virtual

   $ source $HOME/tmp/deepspeech-train-venv/bin/activate

Installing Python dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Installing DeepSpeech Training Code and its dependencies
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Install the required dependencies using ``pip3``\ :

.. code-block:: bash

   cd DeepSpeech
   pip3 install -r requirements.txt
   pip3 install --upgrade pip==20.0.2 wheel==0.34.2 setuptools==46.1.3
   pip3 install --upgrade --force-reinstall -e .

The ``webrtcvad`` Python package might require proper tooling to build Python modules:

@ -54,14 +55,6 @@ The ``webrtcvad`` Python package might require you to ensure you have proper too

   sudo apt-get install python3-dev

You'll also need to install the ``ds_ctcdecoder`` Python package. ``ds_ctcdecoder`` is required for decoding the outputs of the ``deepspeech`` acoustic model into text. You can use ``util/taskcluster.py`` with the ``--decoder`` flag to get a URL to a binary of the decoder package appropriate for your platform and Python version:

.. code-block:: bash

   pip3 install $(python3 util/taskcluster.py --decoder)

This command will download and install the ``ds_ctcdecoder`` package. You can override the platform with ``--arch`` if you want the package for ARM7 (\ ``--arch arm``\ ) or ARM64 (\ ``--arch arm64``\ ). If you prefer building the ``ds_ctcdecoder`` package from source, see the :github:`native_client README file <native_client/README.rst>`.

Recommendations
^^^^^^^^^^^^^^^

@ -70,7 +63,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
.. code-block:: bash

   pip3 uninstall tensorflow
   pip3 install 'tensorflow-gpu==1.15.0'
   pip3 install 'tensorflow-gpu==1.15.2'

Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.

@ -148,6 +141,18 @@ Feel also free to pass additional (or overriding) ``DeepSpeech.py`` parameters t

Each dataset has a corresponding importer script in ``bin/`` that can be used to download (if it's freely available) and preprocess the dataset. See ``bin/import_librivox.py`` for an example of how to import and preprocess a large dataset for training with DeepSpeech.

Some importers might require additional code to properly handle your locale-specific requirements. Such handling is done via the ``--validate_label_locale`` flag, which allows you to source an out-of-tree Python script that defines a ``validate_label`` function. Please refer to ``util/importers.py`` for an example implementation of that function.
If you don't provide this argument, the default ``validate_label`` function will be used. It is only intended for the English language, so you might see consistency issues in your data for other languages.

For example, in order to use a custom validation function that disallows any sample with "a" in its transcript, and lower cases everything else, you could put the following code in a file called ``my_validation.py`` and then use ``--validate_label_locale my_validation.py``:

.. code-block:: python

   def validate_label(label):
       if 'a' in label:  # disallow labels with 'a'
           return None
       return label.lower()  # lower case valid labels

If you've run the old importers (in ``util/importers/``\ ), they could have removed source files that are needed for the new importers to run. In that case, simply remove the extracted folders and let the importer extract and process the dataset from scratch, and things should work.

Training with automatic mixed precision

@ -125,6 +125,8 @@ Please note that as of now, we support:
- Node.JS versions 4 to 13.
- Electron.JS versions 1.6 to 7.1

TypeScript support is also provided.

Alternatively, if you're using Linux and have a supported NVIDIA GPU, you can install the GPU specific package as follows:

.. code-block:: bash
@ -133,7 +135,7 @@ Alternatively, if you're using Linux and have a supported NVIDIA GPU, you can in

See the `release notes <https://github.com/mozilla/DeepSpeech/releases>`_ to find which GPUs are supported. Please ensure you have the required `CUDA dependency <#cuda-dependency>`_.

See :github:`client.js <native_client/javascript/client.js>` for an example of how to use the bindings.
See :github:`client.ts <native_client/javascript/client.ts>` for an example of how to use the bindings.

Using the Command-Line client
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@ -104,7 +104,7 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path
exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store', 'node_modules']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

@ -14,6 +14,8 @@ Welcome to DeepSpeech's documentation!

   TRAINING

   Decoder

.. toctree::
   :maxdepth: 2
   :caption: DeepSpeech Model
@ -52,17 +54,19 @@ Welcome to DeepSpeech's documentation!

   C-Examples

   NodeJS-Examples
   DotNet-Examples

   Java-Examples

   NodeJS-Examples

   Python-Examples

.. toctree::
   :maxdepth: 2
   :caption: Contributed examples

   DotNet-contrib-examples.rst
   DotNet-contrib-examples

   NodeJS-contrib-Examples

158 evaluate.py (executable file → normal file)
@ -2,155 +2,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import json
import sys

from multiprocessing import cpu_count

import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip

from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.evaluate_tools import calculate_and_print_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import create_progressbar, log_error, log_progress

check_ctcdecoder_version()

def sparse_tensor_value_to_texts(value, alphabet):
    r"""
    Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
    representing its values, converting tokens to strings using ``alphabet``.
    """
    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)


def sparse_tuple_to_texts(sp_tuple, alphabet):
    indices = sp_tuple[0]
    values = sp_tuple[1]
    results = [[] for _ in range(sp_tuple[2][0])]
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    # List of strings
    return [alphabet.decode(res) for res in results]
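
A toy walkthrough of the grouping performed by sparse_tuple_to_texts above, reusing that function with an invented two-sequence batch and a stub alphabet (decode() here just maps token ids into a fixed string, standing in for the real Alphabet class):

class StubAlphabet:
    def decode(self, tokens):
        return "".join("abc"[t] for t in tokens)

indices = [[0, 0], [0, 1], [1, 0]]  # (batch_index, time_step) per value
values = [0, 1, 2]                  # token ids
dense_shape = [2, 2]                # two sequences in the batch
print(sparse_tuple_to_texts((indices, values, dense_shape), StubAlphabet()))
# ['ab', 'c']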

def evaluate(test_csvs, create_model):
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.scorer_path, Config.alphabet)
    else:
        scorer = None

    test_csvs = FLAGS.test_files.split(',')
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                 tfv1.data.get_output_shapes(test_sets[0]),
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)

        def run_test(init_op, dataset):
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # Print test summary
            test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
            return test_samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples


def main(_):
    initialize_globals()

    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        sys.exit(1)

    from DeepSpeech import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
    samples = evaluate(FLAGS.test_files.split(','), create_model)

    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats
        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


if __name__ == '__main__':
    create_flags()
    absl.app.run(main)
try:
    from deepspeech_training import evaluate as ds_evaluate
except ImportError:
    print('Training package is not installed. See training documentation.')
    raise

ds_evaluate.run_script()

@ -10,13 +10,12 @@ import csv
import os
import sys

from functools import partial
from six.moves import zip, range
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from deepspeech import Model

from util.evaluate_tools import calculate_and_print_report
from util.flags import create_flags
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
from deepspeech_training.util.flags import create_flags
from functools import partial
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
from six.moves import zip, range

r'''
This module should be self-contained:

@ -2,19 +2,18 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import sys

import optuna
import absl.app
from ds_ctcdecoder import Scorer
import optuna
import sys
import tensorflow.compat.v1 as tfv1

from DeepSpeech import create_model
from evaluate import evaluate
from util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.evaluate_tools import wer_cer_batch
from deepspeech_training.evaluate import evaluate
from deepspeech_training.train import create_model
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error
from deepspeech_training.util.evaluate_tools import wer_cer_batch
from ds_ctcdecoder import Scorer


def character_based():
@ -28,11 +27,23 @@ def objective(trial):
    FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
    FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)

    tfv1.reset_default_graph()
    samples = evaluate(FLAGS.test_files.split(','), create_model)

    is_character_based = trial.study.user_attrs['is_character_based']

    samples = []
    for step, test_file in enumerate(FLAGS.test_files.split(',')):
        tfv1.reset_default_graph()

        current_samples = evaluate([test_file], create_model)
        samples += current_samples

        # Report intermediate objective value.
        wer, cer = wer_cer_batch(current_samples)
        trial.report(cer if is_character_based else wer, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    wer, cer = wer_cer_batch(samples)
    return cer if is_character_based else wer
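
The report/should_prune dance above follows Optuna's standard pruning protocol. A self-contained sketch on a dummy objective; everything below is illustrative, not repo code:

import optuna

def objective(trial):
    x = trial.suggest_uniform("x", 0, 10)
    for step in range(5):
        # Report an intermediate value, then let the pruner kill weak trials.
        trial.report((x - 3) ** 2 + 1.0 / (step + 1), step)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return (x - 3) ** 2

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)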

@ -14,7 +14,10 @@ It is required to use our fork of TensorFlow since it includes fixes for common
If you'd like to build the language bindings or the decoder package, you'll also need:

* `SWIG >= 3.0.12 <http://www.swig.org/>`_. If you intend to build NodeJS / ElectronJS bindings you will need a patched version of SWIG. Please refer to the matching section below.
* `SWIG >= 3.0.12 <http://www.swig.org/>`_.
  Unfortunately, SWIG support for NodeJS / ElectronJS after 10.x is a bit behind, and while there are pending patches proposed upstream, they are not yet merged.
  The proper prebuilt patched version of SWIG (covering Linux, Windows and macOS) should get installed under `native_client/ <native_client/>`_ as soon as you build any bindings that require it.

* `node-pre-gyp <https://github.com/mapbox/node-pre-gyp>`_ (for Node.JS bindings only)

Dependencies
@ -108,10 +111,6 @@ The API mirrors the C++ API and is demonstrated in `client.py <python/client.py>
Install NodeJS / ElectronJS bindings
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Unfortunately, JavaScript support on SWIG is a bit behind, and while there are pending patches proposed upstream, they are not yet merged.
You should be able to build from `our fork <https://github.com/lissyx/swig/tree/taskcluster>`_, and you can find pre-built binaries on `TaskCluster <https://community-tc.services.mozilla.com/tasks/index/project.deepspeech.swig>`_ (please look for the swig fork sha1).
Extract the `ds-swig.tar.gz` to some place in your `$HOME`, then update `$PATH` accordingly. You might need to symlink `ds-swig` as `swig`, and you will have to `export SWIG_LIB=<path/to/swig/share>` so that it contains the path to `share/swig/<VERSION>/`.

After following the above build and installation instructions, the Node.JS bindings can be built:

.. code-block::

@ -162,6 +162,7 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,

  clock_t ds_start_time = clock();

  // sphinx-doc: c_ref_inference_start
  if (extended_output) {
    Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1);
    res.string = CandidateTranscriptToString(&result->transcripts[0]);
@ -198,6 +199,7 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
  } else {
    res.string = DS_SpeechToText(aCtx, aBuffer, aBufferSize);
  }
  // sphinx-doc: c_ref_inference_stop

  clock_t ds_end_infer = clock();

@ -393,6 +395,7 @@ main(int argc, char **argv)

  // Initialise DeepSpeech
  ModelState* ctx;
  // sphinx-doc: c_ref_model_start
  int status = DS_CreateModel(model, &ctx);
  if (status != 0) {
    fprintf(stderr, "Could not create model.\n");
@ -421,6 +424,7 @@ main(int argc, char **argv)
    }
  }
  }
  // sphinx-doc: c_ref_model_stop

#ifndef NO_SOX
  // Initialise SOX

@ -43,16 +43,16 @@ workspace_status.cc:

# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings: clean-keep-third-party workspace_status.cc
bindings: clean-keep-third-party workspace_status.cc ds-swig
	pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
	PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	find temp_build -type f -name "*.o" -delete
	AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	rm -rf temp_build

bindings-debug: clean-keep-third-party workspace_status.cc
bindings-debug: clean-keep-third-party workspace_status.cc ds-swig
	pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
	PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	$(GENERATE_DEBUG_SYMS)
	find temp_build -type f -name "*.o" -delete
	AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
|
||||
|
@ -12,7 +12,11 @@ TOOL_LD := ld
TOOL_LDD := ldd
TOOL_LIBEXE :=

DEEPSPEECH_BIN := deepspeech
ifeq ($(findstring _NT,$(OS)),_NT)
PLATFORM_EXE_SUFFIX := .exe
endif

DEEPSPEECH_BIN := deepspeech$(PLATFORM_EXE_SUFFIX)
CFLAGS_DEEPSPEECH := -std=c++11 -o $(DEEPSPEECH_BIN)
LINK_DEEPSPEECH := -ldeepspeech
LINK_PATH_DEEPSPEECH := -L${TFDIR}/bazel-bin/native_client

@ -36,7 +40,6 @@ endif
endif

ifeq ($(TARGET),host-win)
DEEPSPEECH_BIN := deepspeech.exe
TOOLCHAIN := '$(VCINSTALLDIR)\bin\amd64\'
TOOL_CC := cl.exe
TOOL_CXX := cl.exe

@ -170,3 +173,36 @@ define copy_missing_libs
done; \
fi;
endef

SWIG_DIST_URL ?=
ifeq ($(findstring Linux,$(OS)),Linux)
SWIG_DIST_URL := "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.linux.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz"
else ifeq ($(findstring Darwin,$(OS)),Darwin)
SWIG_DIST_URL := "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.darwin.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz"
else ifeq ($(findstring _NT,$(OS)),_NT)
SWIG_DIST_URL := "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.win.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz"
else
$(error There is no prebuilt SWIG available for your platform. Please produce one and set SWIG_DIST_URL.)
endif

# Should point to native_client/ subdir by default
SWIG_ROOT ?= $(abspath $(shell dirname "$(lastword $(MAKEFILE_LIST))"))/ds-swig
ifeq ($(findstring _NT,$(OS)),_NT)
SWIG_ROOT := $(shell cygpath -u "$(SWIG_ROOT)")
endif
SWIG_LIB ?= $(SWIG_ROOT)/share/swig/4.0.2/

SWIG_BIN := swig$(PLATFORM_EXE_SUFFIX)
DS_SWIG_BIN := ds-swig$(PLATFORM_EXE_SUFFIX)
DS_SWIG_BIN_PATH := $(SWIG_ROOT)/bin

DS_SWIG_ENV := SWIG_LIB="$(SWIG_LIB)" PATH="${PATH}:$(DS_SWIG_BIN_PATH)"

$(DS_SWIG_BIN_PATH)/swig:
mkdir -p $(SWIG_ROOT)
wget -O - "$(SWIG_DIST_URL)" | tar -C $(SWIG_ROOT) -zxf -
ln -s $(DS_SWIG_BIN) $(DS_SWIG_BIN_PATH)/$(SWIG_BIN)

ds-swig: $(DS_SWIG_BIN_PATH)/swig
$(DS_SWIG_ENV) swig -version
$(DS_SWIG_ENV) swig -swiglib
@ -51,8 +51,10 @@ namespace CSharpExamples
{
Console.WriteLine("Loading model...");
stopwatch.Start();
// sphinx-doc: csharp_ref_model_start
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
{
// sphinx-doc: csharp_ref_model_stop
stopwatch.Stop();

Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms");

@ -72,6 +74,7 @@ namespace CSharpExamples
stopwatch.Start();

string speechResult;
// sphinx-doc: csharp_ref_inference_start
if (extended)
{
Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer,

@ -83,6 +86,7 @@ namespace CSharpExamples
speechResult = sttClient.SpeechToText(waveBuffer.ShortBuffer,
Convert.ToUInt32(waveBuffer.MaxSize / 2));
}
// sphinx-doc: csharp_ref_inference_stop

stopwatch.Stop();

@ -99,4 +103,4 @@ namespace CSharpExamples
}
}
}
}
}

@ -27,5 +27,5 @@ maven-bundle: apk
$(GRADLE) uploadArchives
$(GRADLE) zipMavenArtifacts

bindings: clean
swig -c++ -java -package org.mozilla.deepspeech.libdeepspeech -outdir libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/ -o jni/deepspeech_wrap.cpp jni/deepspeech.i
bindings: clean ds-swig
$(DS_SWIG_ENV) swig -c++ -java -package org.mozilla.deepspeech.libdeepspeech -outdir libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/ -o jni/deepspeech_wrap.cpp jni/deepspeech.i
@ -49,8 +49,10 @@ public class DeepSpeechActivity extends AppCompatActivity {
private void newModel(String tfliteModel) {
this._tfliteStatus.setText("Creating model");
if (this._m == null) {
// sphinx-doc: java_ref_model_start
this._m = new DeepSpeechModel(tfliteModel);
this._m.setBeamWidth(BEAM_WIDTH);
// sphinx-doc: java_ref_model_stop
}
}

@ -98,7 +100,9 @@ public class DeepSpeechActivity extends AppCompatActivity {

long inferenceStartTime = System.currentTimeMillis();

// sphinx-doc: java_ref_inference_start
String decoded = this._m.stt(shorts, shorts.length);
// sphinx-doc: java_ref_inference_stop

inferenceExecTime = System.currentTimeMillis() - inferenceStartTime;
@ -1,26 +1,37 @@
NODE_BUILD_TOOL ?= node-pre-gyp
NODE_ABI_TARGET ?=
NODE_BUILD_VERBOSE ?= --verbose
NPM_TOOL ?= npm
PROJECT_NAME ?= deepspeech
PROJECT_VERSION ?= $(shell cat ../../VERSION | tr -d '\n')
NPM_ROOT ?= $(shell npm root)

NODE_MODULES_BIN ?= $(NPM_ROOT)/.bin/
ifeq ($(findstring _NT,$(OS)),_NT)
# On Windows we seem to need both in PATH for node-pre-gyp as well as tsc,
# as they do not get installed the same way.
NODE_MODULES_BIN := $(shell cygpath -u $(NPM_ROOT)/.bin/):$(shell cygpath -u `dirname "$(NPM_ROOT)"`)
endif

include ../definitions.mk

ifeq ($(TARGET),host-win)
ifeq ($(findstring _NT,$(OS)),_NT)
LIBS := '$(shell cygpath -w $(subst .lib,,$(LIBS)))'
endif

.PHONY: npm-dev

default: build

clean:
rm -f deepspeech_wrap.cxx package.json
rm -f deepspeech_wrap.cxx package.json package-lock.json
rm -rf ./build/

clean-npm-pack:
rm -fr ./node_modules/
rm -fr ./deepspeech-*.tgz

really-clean: clean clean-npm-pack
rm -fr ./node_modules/
rm -fr ./lib/

package.json: package.json.in
@ -29,22 +40,27 @@ package.json: package.json.in
-e 's/$$(PROJECT_VERSION)/$(PROJECT_VERSION)/' \
package.json.in > package.json && cat package.json

configure: deepspeech_wrap.cxx package.json
$(NODE_BUILD_TOOL) configure $(NODE_BUILD_VERBOSE)
npm-dev: package.json
ifeq ($(findstring _NT,$(OS)),_NT)
# node-gyp@5.x behaves erratically with VS2015 and MSBuild.exe detection
$(NPM_TOOL) install node-gyp@4.x
endif
$(NPM_TOOL) install --prefix=$(NPM_ROOT)/../ --ignore-scripts --force --verbose --production=false .

configure: deepspeech_wrap.cxx package.json npm-dev
PATH="$(NODE_MODULES_BIN):${PATH}" $(NODE_BUILD_TOOL) configure $(NODE_BUILD_VERBOSE)

build: configure deepspeech_wrap.cxx
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" CXXFLAGS="$(CXXFLAGS)" LDFLAGS="$(RPATH_NODEJS) $(LDFLAGS)" LIBS=$(LIBS) $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) rebuild $(NODE_BUILD_VERBOSE)
PATH="$(NODE_MODULES_BIN):${PATH}" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" CXXFLAGS="$(CXXFLAGS)" LDFLAGS="$(RPATH_NODEJS) $(LDFLAGS)" LIBS=$(LIBS) $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) --no-color rebuild $(NODE_BUILD_VERBOSE)

copy-deps: build
$(call copy_missing_libs,lib/binding/*/*/*/deepspeech.node,lib/binding/*/*/)

node-wrapper: copy-deps build
$(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) package $(NODE_BUILD_VERBOSE)
PATH="$(NODE_MODULES_BIN):${PATH}" $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DIST_URL) --no-color package $(NODE_BUILD_VERBOSE)

npm-pack: clean package.json index.js
npm install node-pre-gyp@0.14.x
npm pack $(NODE_BUILD_VERBOSE)
npm-pack: clean package.json index.js npm-dev
PATH="$(NODE_MODULES_BIN):${PATH}" tsc && $(NPM_TOOL) pack $(NODE_BUILD_VERBOSE)

deepspeech_wrap.cxx: deepspeech.i
swig -version
swig -c++ -javascript -node deepspeech.i
deepspeech_wrap.cxx: deepspeech.i ds-swig
$(DS_SWIG_ENV) swig -c++ -javascript -node deepspeech.i
@ -1 +1,18 @@
Full project description and documentation on GitHub: [https://github.com/mozilla/DeepSpeech](https://github.com/mozilla/DeepSpeech).

## Generating TypeScript Type Definitions

You can generate the TypeScript type declaration file using `dts-gen`.
This requires a compiled/installed version of the DeepSpeech NodeJS client.

After an API change, you need to generate a new `index.d.ts` type declaration file by running:

```sh
npm install -g dts-gen
dts-gen --module deepspeech --file index.d.ts
```

### Example usage

See `client.ts`
@ -1,48 +1,42 @@
#!/usr/bin/env node
'use strict';

const Fs = require('fs');
const Sox = require('sox-stream');
const Ds = require('./index.js');
const argparse = require('argparse');
const MemoryStream = require('memory-stream');
const Wav = require('node-wav');
const Duplex = require('stream').Duplex;
const util = require('util');
// This is required for process.versions.electron below
/// <reference types="electron" />

var VersionAction = function VersionAction(options) {
options = options || {};
options.nargs = 0;
argparse.Action.call(this, options);
}
util.inherits(VersionAction, argparse.Action);
import Ds from "./index";
import * as Fs from "fs";
import Sox from "sox-stream";
import * as argparse from "argparse";

VersionAction.prototype.call = function(parser) {
console.log('DeepSpeech ' + Ds.Version());
let runtime = 'Node';
if (process.versions.electron) {
runtime = 'Electron';
const MemoryStream = require("memory-stream");
const Wav = require("node-wav");
const Duplex = require("stream").Duplex;

class VersionAction extends argparse.Action {
call(parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: string | string[], optionString: string | null) {
console.log('DeepSpeech ' + Ds.Version());
let runtime = 'Node';
if (process.versions.electron) {
runtime = 'Electron';
}
console.error('Runtime: ' + runtime);
process.exit(0);
}
console.error('Runtime: ' + runtime);
process.exit(0);
}

var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'});
let parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'});
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
parser.addArgument(['--version'], {action: VersionAction, nargs: 0, help: 'Print version and exits'});
parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'});
var args = parser.parseArgs();
let args = parser.parseArgs();

function totalTime(hrtimeValue) {
function totalTime(hrtimeValue: number[]): string {
return (hrtimeValue[0] + hrtimeValue[1] / 1000000000).toPrecision(4);
}

function candidateTranscriptToString(transcript) {
function candidateTranscriptToString(transcript: Ds.CandidateTranscript): string {
var retval = ""
for (var i = 0; i < transcript.tokens.length; ++i) {
retval += transcript.tokens[i].text;
@ -50,17 +44,19 @@ function candidateTranscriptToString(transcript) {
return retval;
}

// sphinx-doc: js_ref_model_start
console.error('Loading model from file %s', args['model']);
const model_load_start = process.hrtime();
var model = new Ds.Model(args['model']);
let model = new Ds.Model(args['model']);
const model_load_end = process.hrtime(model_load_start);
console.error('Loaded model in %ds.', totalTime(model_load_end));

if (args['beam_width']) {
model.setBeamWidth(args['beam_width']);
}
// sphinx-doc: js_ref_model_stop

var desired_sample_rate = model.sampleRate();
let desired_sample_rate = model.sampleRate();

if (args['scorer']) {
console.error('Loading scorer from file %s', args['scorer']);
@ -78,23 +74,24 @@ const buffer = Fs.readFileSync(args['audio']);
const result = Wav.decode(buffer);

if (result.sampleRate < desired_sample_rate) {
console.error('Warning: original sample rate (' + result.sampleRate + ') ' +
'is lower than ' + desired_sample_rate + 'Hz. ' +
'Up-sampling might produce erratic speech recognition.');
console.error(`Warning: original sample rate (${result.sampleRate}) ` +
`is lower than ${desired_sample_rate} Hz. ` +
`Up-sampling might produce erratic speech recognition.`);
}

function bufferToStream(buffer) {
function bufferToStream(buffer: Buffer) {
var stream = new Duplex();
stream.push(buffer);
stream.push(null);
return stream;
}

var audioStream = new MemoryStream();
let audioStream = new MemoryStream();
bufferToStream(buffer).
pipe(Sox({
global: {
'no-dither': true,
'replay-gain': 'off',
},
output: {
bits: 16,
@ -115,6 +112,7 @@ audioStream.on('finish', () => {
console.error('Running inference.');
const audioLength = (audioBuffer.length / 2) * (1 / desired_sample_rate);

// sphinx-doc: js_ref_inference_start
if (args['extended']) {
let metadata = model.sttWithMetadata(audioBuffer, 1);
console.log(candidateTranscriptToString(metadata.transcripts[0]));

@ -122,6 +120,7 @@ audioStream.on('finish', () => {
} else {
console.log(model.stt(audioBuffer));
}
// sphinx-doc: js_ref_inference_stop
const inference_stop = process.hrtime(inference_start);
console.error('Inference took %ds for %ds audio file.', totalTime(inference_stop), audioLength.toPrecision(4));
Ds.FreeModel(model);
196
native_client/javascript/index.d.ts
vendored
Normal file
@ -0,0 +1,196 @@
/**
 * Stores text of an individual token, along with its timing information
 */
export interface TokenMetadata {
    text: string;
    timestep: number;
    start_time: number;
}

/**
 * A single transcript computed by the model, including a confidence value and
 * the metadata for its constituent tokens.
 */
export interface CandidateTranscript {
    tokens: TokenMetadata[];
    confidence: number;
}

/**
 * An array of CandidateTranscript objects computed by the model.
 */
export interface Metadata {
    transcripts: CandidateTranscript[];
}
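To make the token-level metadata above more concrete, here is a small, hypothetical TypeScript helper (editorial illustration, not part of the bindings) that groups `TokenMetadata` entries into words with start times. It assumes the default English alphabet, where tokens are single characters and a space token marks a word boundary:

```typescript
import { CandidateTranscript } from "deepspeech";  // import path is illustrative

interface WordTiming {
    word: string;
    startTime: number;  // seconds, taken from the first token of the word
}

// Walk the per-token metadata of one transcript and assemble words.
function wordsWithTimings(transcript: CandidateTranscript): WordTiming[] {
    const words: WordTiming[] = [];
    let current = "";
    let start = 0;
    for (const token of transcript.tokens) {
        if (token.text === " ") {
            // Space token: close the current word, if any.
            if (current.length > 0) {
                words.push({ word: current, startTime: start });
                current = "";
            }
        } else {
            // First character of a new word carries its start time.
            if (current.length === 0) {
                start = token.start_time;
            }
            current += token.text;
        }
    }
    if (current.length > 0) {
        words.push({ word: current, startTime: start });
    }
    return words;
}
```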
/**
 * An object providing an interface to a trained DeepSpeech model.
 *
 * @param aModelPath The path to the frozen model graph.
 *
 * @throws on error
 */
export class Model {
    constructor(aModelPath: string)

    /**
     * Get beam width value used by the model. If :js:func:`Model.setBeamWidth` was
     * not called before, will return the default value loaded from the model file.
     *
     * @return Beam width value used by the model.
     */
    beamWidth(): number;

    /**
     * Set beam width value used by the model.
     *
     * @param aBeamWidth The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
     *
     * @return Zero on success, non-zero on failure.
     */
    setBeamWidth(aBeamWidth: number): number;

    /**
     * Return the sample rate expected by the model.
     *
     * @return Sample rate.
     */
    sampleRate(): number;

    /**
     * Enable decoding using an external scorer.
     *
     * @param aScorerPath The path to the external scorer file.
     *
     * @return Zero on success, non-zero on failure (invalid arguments).
     */
    enableExternalScorer(aScorerPath: string): number;

    /**
     * Disable decoding using an external scorer.
     *
     * @return Zero on success, non-zero on failure (invalid arguments).
     */
    disableExternalScorer(): number;

    /**
     * Set hyperparameters alpha and beta of the external scorer.
     *
     * @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model weight.
     * @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion weight.
     *
     * @return Zero on success, non-zero on failure (invalid arguments).
     */
    setScorerAlphaBeta(aLMAlpha: number, aLMBeta: number): number;

    /**
     * Use the DeepSpeech model to perform Speech-To-Text.
     *
     * @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
     *
     * @return The STT result. Returns undefined on error.
     */
    stt(aBuffer: object): string;

    /**
     * Use the DeepSpeech model to perform Speech-To-Text and output metadata
     * about the results.
     *
     * @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
     * @param aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this.
     *                    Default value is 1 if not specified.
     *
     * @return :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
     *         The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
     */
    sttWithMetadata(aBuffer: object, aNumResults: number): Metadata;

    /**
     * Create a new streaming inference state. One can then call :js:func:`Stream.feedAudioContent` and :js:func:`Stream.finishStream` on the returned stream object.
     *
     * @return a :js:func:`Stream` object that represents the streaming state.
     *
     * @throws on error
     */
    createStream(): object;
}
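As a rough usage sketch of the `Model` API declared above (editorial illustration, not from the repository; the scorer path, alpha/beta values, and raw-audio handling are assumed, and `client.ts` above shows the full Sox-based pipeline):

```typescript
import Ds from "deepspeech";  // import path is illustrative
import * as Fs from "fs";

// Load the model and (optionally) an external scorer.
let model = new Ds.Model("output_graph.pbmm");  // throws on error
model.enableExternalScorer("kenlm.scorer");     // returns non-zero on failure
model.setScorerAlphaBeta(0.75, 1.85);           // hypothetical alpha/beta values

// stt() expects 16-bit mono PCM at model.sampleRate(); here we assume
// audio.raw already matches that format.
const audioBuffer = Fs.readFileSync("audio.raw");
console.log(model.stt(audioBuffer));

// Free native resources when done.
Ds.FreeModel(model);
```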
/**
 * @class
 * Provides an interface to a DeepSpeech stream. The constructor cannot be called
 * directly, use :js:func:`Model.createStream`.
 */
declare class Stream {
    /**
     * Feed audio samples to an ongoing streaming inference.
     *
     * @param aBuffer An array of 16-bit, mono raw audio samples at the
     *                appropriate sample rate (matching what the model was trained on).
     */
    feedAudioContent(aBuffer: object): void;

    /**
     * Compute the intermediate decoding of an ongoing streaming inference.
     *
     * @return The STT intermediate result.
     */
    intermediateDecode(): string;

    /**
     * Compute the intermediate decoding of an ongoing streaming inference, returning results including metadata.
     *
     * @param aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
     *
     * @return :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
     */
    intermediateDecodeWithMetadata(aNumResults: number): Metadata;

    /**
     * Compute the final decoding of an ongoing streaming inference and return the result. Signals the end of an ongoing streaming inference.
     *
     * @return The STT result.
     *
     * This method will free the stream; it must not be used after this method is called.
     */
    finishStream(): string;

    /**
     * Compute the final decoding of an ongoing streaming inference and return the results including metadata. Signals the end of an ongoing streaming inference.
     *
     * @param aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified.
     *
     * @return Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`.
     *
     * This method will free the stream; it must not be used after this method is called.
     */
    finishStreamWithMetadata(aNumResults: number): Metadata;
}
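A minimal streaming sketch under the same caveats (the chunking and the `as any` cast are editorial; `createStream()` is typed as `object`, so the `Stream` methods are not statically visible):

```typescript
import Ds from "deepspeech";  // import path is illustrative

function streamingTranscribe(model: Ds.Model, chunks: Buffer[]): string {
    // createStream() returns the streaming state; cast since it is typed `object`.
    const stream = model.createStream() as any;
    for (const chunk of chunks) {
        stream.feedAudioContent(chunk);  // 16-bit mono PCM at model.sampleRate()
        console.error("partial:", stream.intermediateDecode());
    }
    // finishStream() frees the stream; it must not be used afterwards.
    return stream.finishStream();
}
```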
/**
 * Frees associated resources and destroys the model object.
 *
 * @param model A model pointer returned by :js:func:`Model`
 */
export function FreeModel(model: Model): void;

/**
 * Free memory allocated for metadata information.
 *
 * @param metadata Object containing metadata as returned by :js:func:`Model.sttWithMetadata` or :js:func:`Stream.finishStreamWithMetadata`
 */
export function FreeMetadata(metadata: Metadata): void;

/**
 * Destroy a streaming state without decoding the computed logits. This
 * can be used if you no longer need the result of an ongoing streaming
 * inference and don't want to perform a costly decode operation.
 *
 * @param stream A streaming state pointer returned by :js:func:`Model.createStream`.
 */
export function FreeStream(stream: object): void;

/**
 * Print version of this library and of the linked TensorFlow library on standard output.
 */
export function Version(): void;
@ -2,7 +2,8 @@
"name" : "$(PROJECT_NAME)",
"version" : "$(PROJECT_VERSION)",
"description" : "DeepSpeech NodeJS bindings",
"main" : "./index",
"main" : "./index.js",
"types": "./index.d.ts",
"bin": {
"deepspeech": "./client.js"
},

@ -13,6 +14,7 @@
"README.md",
"client.js",
"index.js",
"index.d.ts",
"lib/*"
],
"bugs": {

@ -37,6 +39,11 @@
"node-wav": "0.0.2"
},
"devDependencies": {
"electron": "^1.7.9",
"node-gyp": "4.x - 5.x",
"typescript": "3.6.x",
"@types/argparse": "1.0.x",
"@types/node": "13.9.x"
},
"scripts": {
"test": "node index.js"
18
native_client/javascript/tsconfig.json
Normal file
@ -0,0 +1,18 @@
{
    "compilerOptions": {
        "baseUrl": ".",
        "target": "es6",
        "module": "commonjs",
        "moduleResolution": "node",
        "esModuleInterop": true,
        "noImplicitAny": true,
        "noImplicitThis": true,
        "strictFunctionTypes": true,
        "strictNullChecks": true,
        "forceConsistentCasingInFileNames": true
    },
    "files": [
        "index.d.ts",
        "client.ts"
    ]
}
@ -8,9 +8,9 @@ bindings-clean:

# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings-build:
bindings-build: ds-swig
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
PATH=$(TOOLCHAIN):$$PATH AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python ./setup.py build_ext $(PYTHON_PLATFORM_NAME)
PATH=$(TOOLCHAIN):$(DS_SWIG_BIN_PATH):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python ./setup.py build_ext $(PYTHON_PLATFORM_NAME)

MANIFEST.in: bindings-build
> $@
@ -111,7 +111,9 @@ def main():

print('Loading model from file {}'.format(args.model), file=sys.stderr)
model_load_start = timer()
# sphinx-doc: python_ref_model_start
ds = Model(args.model)
# sphinx-doc: python_ref_model_stop
model_load_end = timer() - model_load_start
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)

@ -131,24 +133,26 @@ def main():
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)

fin = wave.open(args.audio, 'rb')
fs = fin.getframerate()
if fs != desired_sample_rate:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs, desired_sample_rate), file=sys.stderr)
fs, audio = convert_samplerate(args.audio, desired_sample_rate)
fs_orig = fin.getframerate()
if fs_orig != desired_sample_rate:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs_orig, desired_sample_rate), file=sys.stderr)
fs_new, audio = convert_samplerate(args.audio, desired_sample_rate)
else:
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

audio_length = fin.getnframes() * (1/fs)
audio_length = fin.getnframes() * (1/fs_orig)
fin.close()

print('Running inference.', file=sys.stderr)
inference_start = timer()
# sphinx-doc: python_ref_inference_start
if args.extended:
print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
elif args.json:
print(metadata_json_output(ds.sttWithMetadata(audio, 3)))
else:
print(ds.stt(audio))
# sphinx-doc: python_ref_inference_stop
inference_end = timer() - inference_start
print('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length), file=sys.stderr)
@ -1,24 +0,0 @@
# Main training requirements
tensorflow == 1.15.2
numpy == 1.18.1
progressbar2
six
pyxdg
attrdict
absl-py
semver
opuslib == 2.0.0

# Requirements for building native_client files
setuptools

# Requirements for importers
sox
bs4
pandas
requests
librosa
soundfile

# Requirements for optimizer
optuna
125
setup.py
Normal file
@ -0,0 +1,125 @@
import os
import platform
import sys
from pathlib import Path

from pkg_resources import parse_version
from setuptools import find_packages, setup


def get_decoder_pkg_url(version, artifacts_root=None):
    is_arm = 'arm' in platform.machine()
    is_mac = 'darwin' in sys.platform
    is_64bit = sys.maxsize > (2**31 - 1)

    if is_arm:
        tc_arch = 'arm64-ctc' if is_64bit else 'arm-ctc'
    elif is_mac:
        tc_arch = 'osx-ctc'
    else:
        tc_arch = 'cpu-ctc'

    ds_version = parse_version(version)
    branch = "v{}".format(version)

    plat = platform.system().lower()
    arch = platform.machine()

    if plat == 'linux' and arch == 'x86_64':
        plat = 'manylinux1'

    if plat == 'darwin':
        plat = 'macosx_10_10'

    is_ucs2 = sys.maxunicode < 0x10ffff
    m_or_mu = 'mu' if is_ucs2 else 'm'

    pyver = ''.join(str(i) for i in sys.version_info[0:2])

    if not artifacts_root:
        artifacts_root = 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.{branch_name}.{tc_arch_string}/artifacts/public'.format(
            branch_name=branch,
            tc_arch_string=tc_arch)

    return 'ds_ctcdecoder @ {artifacts_root}/ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl'.format(
        artifacts_root=artifacts_root,
        ds_version=ds_version,
        pyver=pyver,
        m_or_mu=m_or_mu,
        platform=plat,
        arch=arch,
    )


def main():
    version_file = Path(__file__).parent / 'VERSION'
    with open(str(version_file)) as fin:
        version = fin.read().strip()

    decoder_pkg_url = get_decoder_pkg_url(version)

    install_requires_base = [
        'tensorflow == 1.15.2',
        'numpy == 1.18.1',
        'progressbar2',
        'six',
        'pyxdg',
        'attrdict',
        'absl-py',
        'semver',
        'opuslib == 2.0.0',
        'optuna',
        'sox',
        'bs4',
        'pandas',
        'requests',
        'librosa',
        'soundfile',
    ]

    # Due to pip craziness, environment variables are the only consistent way to
    # get options into this script when doing `pip install`.
    tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
    if tc_decoder_artifacts_root:
        # We're running inside the TaskCluster environment, override the decoder
        # package URL with the one we just built.
        decoder_pkg_url = get_decoder_pkg_url(version, tc_decoder_artifacts_root)
        install_requires = install_requires_base + [decoder_pkg_url]
    elif os.environ.get('DS_NODECODER', ''):
        install_requires = install_requires_base
    else:
        install_requires = install_requires_base + [decoder_pkg_url]

    setup(
        name='deepspeech_training',
        version=version,
        description='Training code for mozilla DeepSpeech',
        url='https://github.com/mozilla/DeepSpeech',
        author='Mozilla',
        license='MPL-2.0',
        # Classifiers help users find your project by categorizing it.
        #
        # For a list of valid classifiers, see https://pypi.org/classifiers/
        classifiers=[
            'Development Status :: 3 - Alpha',
            'Intended Audience :: Developers',
            'Topic :: Multimedia :: Sound/Audio :: Speech',
            'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
            'Programming Language :: Python :: 3',
        ],
        package_dir={'': 'training'},
        packages=find_packages(where='training'),
        python_requires='>=3.5, <4',
        install_requires=install_requires,
        # If there are data files included in your packages that need to be
        # installed, specify them here.
        package_data={
            'deepspeech_training': [
                'VERSION',
                'GRAPH_VERSION',
            ],
        },
    )


if __name__ == '__main__':
    main()
43
stats.py
@ -1,10 +1,29 @@
#!/usr/bin/env python3

import argparse
import os
import functools
import pandas

from deepspeech_training.util.helpers import secs_to_hours
from pathlib import Path


def read_csvs(csv_files):
# Relative paths are relative to CSV location
def absolutify(csv, path):
path = Path(path)
if path.is_absolute():
return str(path)
return str(csv.parent / path)

sets = []
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
sets.append(file)

# Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True)

from util.helpers import secs_to_hours
from util.feeding import read_csvs

def main():
parser = argparse.ArgumentParser()

@ -14,20 +33,16 @@ def main():
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
in_files = [Path(i).absolute() for i in args.csv_files.split(",")]

csv_dataframe = read_csvs(in_files)
total_bytes = csv_dataframe['wav_filesize'].sum()
total_files = len(csv_dataframe.index)
total_files = len(csv_dataframe)
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()

bytes_without_headers = total_bytes - 44 * total_files

total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8)

print('total_bytes', total_bytes)
print('total_files', total_files)
print('bytes_without_headers', bytes_without_headers)
print('total_time', secs_to_hours(total_time))
print('Total bytes:', total_bytes)
print('Total files:', total_files)
print('Total time:', secs_to_hours(total_seconds))

if __name__ == '__main__':
main()
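As a quick sanity check of the header arithmetic above (a 44-byte canonical WAV header followed by raw PCM), here is the same formula as a small TypeScript helper with a worked example; the function and its defaults are editorial, not part of the repository:

```typescript
// Duration of a canonical (44-byte header) PCM WAV file, from its size alone.
function wavDurationSeconds(
    fileSizeBytes: number,
    sampleRate = 16000,
    channels = 1,
    bitsPerSample = 16,
): number {
    const payload = fileSizeBytes - 44;  // strip the RIFF/WAVE header
    return payload / (sampleRate * channels * (bitsPerSample / 8));
}

// One second of 16 kHz mono 16-bit audio is 32000 sample bytes + 44 header bytes.
console.log(wavDurationSeconds(32044));  // 1
```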
@ -115,10 +115,6 @@ system:
swig:
repo: "https://github.com/lissyx/swig"
sha1: "b5fea54d39832d1d132d7dd921b69c0c2c9d5118"
cache:
linux_amd64: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.linux.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz'
darwin_amd64: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.darwin.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz'
win_amd64: 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.swig.win.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118/artifacts/public/ds-swig.tar.gz'
username: 'build-user'
homedir:
linux: '/home/build-user'

@ -88,10 +88,6 @@ payload:
format: tar.gz
content:
url: ${system.node_gyp_cache.url}
- directory: ds-swig
format: tar.gz
content:
url: ${system.swig.cache.darwin_amd64}
- file: home.tar.xz
content:
url: ${build.tensorflow}
@ -1,6 +1,6 @@
breathe==4.13.1
breathe==4.14.2
semver==2.8.1
sphinx==2.2.0
sphinx==2.4.4
sphinx-js==2.8
sphinx-rtd-theme==0.4.3
pygments==2.4.2
pygments==2.6.1
@ -48,7 +48,7 @@ then:
adduser --system --home ${system.homedir.linux} ${system.username} &&
apt-get -qq update && apt-get -qq -y install ${tensorflow.packages_trusty.apt} pixz pkg-config realpath unzip wget zip && ${extraSystemSetup} &&
cd ${system.homedir.linux}/ &&
echo -e "#!/bin/bash\nset -xe\n env && id && (wget -O - $TENSORFLOW_BUILD_ARTIFACT | pixz -d | tar -C ${system.homedir.linux}/ -xf - ) && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && ln -s ~/DeepSpeech/ds/native_client/ ~/DeepSpeech/tf/native_client && mkdir -p ${system.homedir.linux}/.cache/node-gyp/ && wget -O - ${system.node_gyp_cache.url} | tar -C ${system.homedir.linux}/.cache/node-gyp/ -xzf - && mkdir -p ${system.homedir.linux}/ds-swig/bin/ && wget -O - ${system.swig.cache.linux_amd64} | tar -C ${system.homedir.linux}/ds-swig/ -xzf - && mkdir -p ${system.homedir.linux}/pyenv-root/ && wget -O - ${system.pyenv.linux.url} | tar -C ${system.homedir.linux}/pyenv-root/ -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi && if [ ! -z "${build.android_cache.url}" ]; then wget -O - ${build.android_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
echo -e "#!/bin/bash\nset -xe\n env && id && (wget -O - $TENSORFLOW_BUILD_ARTIFACT | pixz -d | tar -C ${system.homedir.linux}/ -xf - ) && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && ln -s ~/DeepSpeech/ds/native_client/ ~/DeepSpeech/tf/native_client && mkdir -p ${system.homedir.linux}/.cache/node-gyp/ && wget -O - ${system.node_gyp_cache.url} | tar -C ${system.homedir.linux}/.cache/node-gyp/ -xzf - && mkdir -p ${system.homedir.linux}/pyenv-root/ && wget -O - ${system.pyenv.linux.url} | tar -C ${system.homedir.linux}/pyenv-root/ -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi && if [ ! -z "${build.android_cache.url}" ]; then wget -O - ${build.android_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
sudo -H -u ${system.username} /bin/bash /tmp/clone.sh && ${extraSystemConfig} &&
sudo -H -u ${system.username} --preserve-env /bin/bash ${system.homedir.linux}/DeepSpeech/ds/${build.scripts.build} &&
sudo -H -u ${system.username} /bin/bash ${system.homedir.linux}/DeepSpeech/ds/${build.scripts.package}
@ -122,25 +122,3 @@ verify_bazel_rebuild()
exit 1
fi;
}

# Should be called from context where Python virtualenv is set
verify_ctcdecoder_url()
{
default_url=$(python util/taskcluster.py --decoder)
echo "${default_url}" | grep -F "deepspeech.native_client.v${DS_VERSION}"
rc_default_url=$?

tag_url=$(python util/taskcluster.py --decoder --branch 'v1.2.3')
echo "${tag_url}" | grep -F "deepspeech.native_client.v1.2.3"
rc_tag_url=$?

master_url=$(python util/taskcluster.py --decoder --branch 'master')
echo "${master_url}" | grep -F "deepspeech.native_client.master"
rc_master_url=$?

if [ ${rc_default_url} -eq 0 -a ${rc_tag_url} -eq 0 -a ${rc_master_url} -eq 0 ]; then
return 0
else
return 1
fi;
}
@ -6,14 +6,12 @@ export OS=$(uname)
if [ "${OS}" = "Linux" ]; then
export DS_ROOT_TASK=${HOME}
export PYENV_ROOT="${DS_ROOT_TASK}/pyenv-root"
export SWIG_ROOT="${HOME}/ds-swig"
export DS_CPU_COUNT=$(nproc)
fi;

if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
export DS_ROOT_TASK=${TASKCLUSTER_TASK_DIR}
export PYENV_ROOT="${TASKCLUSTER_TASK_DIR}/pyenv-root"
export SWIG_ROOT="$(cygpath ${USERPROFILE})/ds-swig"
export PLATFORM_EXE_SUFFIX=.exe
export DS_CPU_COUNT=$(nproc)

@ -22,7 +20,6 @@ if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
fi;

if [ "${OS}" = "Darwin" ]; then
export SWIG_ROOT="${TASKCLUSTER_ORIG_TASKDIR}/ds-swig"
export DS_ROOT_TASK=${TASKCLUSTER_TASK_DIR}
export DS_CPU_COUNT=$(sysctl hw.ncpu |cut -d' ' -f2)
export PYENV_ROOT="${DS_ROOT_TASK}/pyenv-root"

@ -45,19 +42,6 @@ if [ "${OS}" = "Darwin" ]; then
fi;
fi;

SWIG_BIN=swig${PLATFORM_EXE_SUFFIX}
DS_SWIG_BIN=ds-swig${PLATFORM_EXE_SUFFIX}
if [ -f "${SWIG_ROOT}/bin/${DS_SWIG_BIN}" ]; then
export PATH=${SWIG_ROOT}/bin/:$PATH
export SWIG_LIB="$(find ${SWIG_ROOT}/share/swig/ -type f -name "swig.swg" | xargs dirname)"
# Make an alias to be more magic
if [ ! -L "${SWIG_ROOT}/bin/${SWIG_BIN}" ]; then
ln -s ${DS_SWIG_BIN} ${SWIG_ROOT}/bin/${SWIG_BIN}
fi;
swig -version
swig -swiglib
fi;

PY37_OPENSSL_DIR="${PYENV_ROOT}/ssl-xenial"
export PY37_LDPATH="${PY37_OPENSSL_DIR}/usr/lib/"
export LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH
@ -105,16 +105,10 @@ do_deepspeech_nodejs_build()
# Python 2.7 is required for node-pre-gyp; we only need to force it on
# Windows
if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
NPM_ROOT=$(cygpath -u "$(npm root)")
PYTHON27=":/c/Python27"
# node-gyp@5.x behaves erratically with VS2015 and MSBuild.exe detection
npm install node-gyp@4.x node-pre-gyp
else
NPM_ROOT="$(npm root)"
npm install node-gyp@5.x node-pre-gyp
PYTHON27="/c/Python27"
fi

export PATH="$NPM_ROOT/.bin/${PYTHON27}:$PATH"
export PATH="${PYTHON27}:$PATH"

for node in ${SUPPORTED_NODEJS_VERSIONS}; do
EXTRA_CFLAGS="${EXTRA_LOCAL_CFLAGS}" EXTRA_LDFLAGS="${EXTRA_LOCAL_LDFLAGS}" EXTRA_LIBS="${EXTRA_LOCAL_LIBS}" make -C native_client/javascript \

@ -157,16 +151,10 @@ do_deepspeech_npm_package()
# Python 2.7 is required for node-pre-gyp; we only need to force it on
# Windows
if [ "${OS}" = "${TC_MSYS_VERSION}" ]; then
NPM_ROOT=$(cygpath -u "$(npm root)")
PYTHON27=":/c/Python27"
# node-gyp@5.x behaves erratically with VS2015 and MSBuild.exe detection
npm install node-gyp@4.x node-pre-gyp
else
NPM_ROOT="$(npm root)"
npm install node-gyp@5.x node-pre-gyp
PYTHON27="/c/Python27"
fi

export PATH="$NPM_ROOT/.bin/$PYTHON27:$PATH"
export PATH="${NPM_BIN}${PYTHON27}:$PATH"

all_tasks="$(curl -s https://community-tc.services.mozilla.com/api/queue/v1/task/${TASK_ID} | python -c 'import json; import sys; print(" ".join(json.loads(sys.stdin.read())["dependencies"]));')"
@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
set -o pipefail
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail

which deepspeech

@ -17,12 +17,11 @@ virtualenv_activate "${pyalias}" "deepspeech"

set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd
set +o pipefail

decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${decoder_pkg_url} | cat

pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-ldc93s1_singleshotinference.sh
popd

@ -16,15 +16,10 @@ virtualenv_activate "${pyalias}" "deepspeech"

set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
set +o pipefail

pushd ${HOME}/DeepSpeech/ds/
verify_ctcdecoder_url
pushd ${HOME}/DeepSpeech/ds
pip install --upgrade . | cat
popd

decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
set +o pipefail

# Prepare correct arguments for training
case "${bitrate}" in

@ -14,15 +14,10 @@ virtualenv_activate "${pyalias}" "deepspeech"

set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
set +o pipefail

pushd ${HOME}/DeepSpeech/ds/
verify_ctcdecoder_url
pushd ${HOME}/DeepSpeech/ds
DS_NODECODER=1 pip install --upgrade . | cat
popd

decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
set +o pipefail

pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-transfer.sh
@ -50,7 +50,7 @@ then:
${extraSystemSetup} && chmod 777 /dev/kvm &&
adduser --system --home ${system.homedir.linux} ${system.username} &&
cd ${system.homedir.linux} &&
echo -e "#!/bin/bash\nset -xe\n env && id && mkdir ~/DeepSpeech/ && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && mkdir -p ${system.homedir.linux}/ds-swig/bin/ && wget -O - ${system.swig.cache.linux_amd64} | tar -C ${system.homedir.linux}/ds-swig/ -xzf - && wget -O - ${build.cache.url} | tar -C ${system.homedir.linux} -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
echo -e "#!/bin/bash\nset -xe\n env && id && mkdir ~/DeepSpeech/ && git clone --quiet ${event.head.repo.url} ~/DeepSpeech/ds/ && cd ~/DeepSpeech/ds && git checkout --quiet ${event.head.sha} && wget -O - ${build.cache.url} | tar -C ${system.homedir.linux} -xzf - && if [ ! -z "${build.gradle_cache.url}" ]; then wget -O - ${build.gradle_cache.url} | tar -C ${system.homedir.linux}/ -xzf - ; fi;" > /tmp/clone.sh && chmod +x /tmp/clone.sh &&
sudo -H -u ${system.username} /bin/bash /tmp/clone.sh &&
sudo -H -u ${system.username} --preserve-env /bin/bash ${build.args.tests_cmdline}
@ -85,10 +85,6 @@ payload:
format: tar.gz
content:
url: ${system.node_gyp_cache.url}
- directory: ds-swig
format: tar.gz
content:
url: ${system.swig.cache.win_amd64}

artifacts:
- type: "directory"
@ -1,10 +1,14 @@
import unittest

from argparse import Namespace
from .importers import validate_label_eng, get_validate_label
from deepspeech_training.util.importers import validate_label_eng, get_validate_label
from pathlib import Path

def from_here(path):
here = Path(__file__)
return here.parent / path

class TestValidateLabelEng(unittest.TestCase):

def test_numbers(self):
label = validate_label_eng("this is a 1 2 3 test")
self.assertEqual(label, None)

@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
self.assertEqual(f('toto1234[{[{[]'), None)

def test_get_validate_label_missing(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
f = get_validate_label(args)
self.assertEqual(f, None)

def test_get_validate_label(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
f = get_validate_label(args)
l = f('toto')
self.assertEqual(l, 'toto')

@ -1,7 +1,7 @@
import unittest
import os

from .text import Alphabet
from deepspeech_training.util.text import Alphabet

class TestAlphabetParsing(unittest.TestCase):
1
training/deepspeech_training/GRAPH_VERSION
Symbolic link
@ -0,0 +1 @@
../../GRAPH_VERSION

1
training/deepspeech_training/VERSION
Symbolic link
@ -0,0 +1 @@
../../VERSION

0
training/deepspeech_training/__init__.py
Normal file

155
training/deepspeech_training/evaluate.py
Executable file
@ -0,0 +1,155 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import json
import sys

from multiprocessing import cpu_count

import absl.app
import progressbar
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip

from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph_for_evaluation
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version
from .util.logging import create_progressbar, log_error, log_progress

check_ctcdecoder_version()

def sparse_tensor_value_to_texts(value, alphabet):
    r"""
    Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
    representing its values, converting tokens to strings using ``alphabet``.
    """
    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)


def sparse_tuple_to_texts(sp_tuple, alphabet):
    indices = sp_tuple[0]
    values = sp_tuple[1]
    results = [[] for _ in range(sp_tuple[2][0])]
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    # List of strings
    return [alphabet.decode(res) for res in results]


def evaluate(test_csvs, create_model):
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.scorer_path, Config.alphabet)
    else:
        scorer = None

    test_csvs = FLAGS.test_files.split(',')
    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                 tfv1.data.get_output_shapes(test_sets[0]),
                                                 output_classes=tfv1.data.get_output_classes(test_sets[0]))
    test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()

    # One rate per layer
    no_dropout = [None] * 6
    logits, _ = create_model(batch_x=batch_x,
                             batch_size=FLAGS.test_batch_size,
                             seq_length=batch_x_len,
                             dropout=no_dropout)

    # Transpose to batch major and apply softmax for decoder
    transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))

    loss = tfv1.nn.ctc_loss(labels=batch_y,
                            inputs=logits,
                            sequence_length=batch_x_len)

    tfv1.train.get_or_create_global_step()

    # Get number of accessible CPU cores for this process
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        load_graph_for_evaluation(session)

        def run_test(init_op, dataset):
            wav_filenames = []
            losses = []
            predictions = []
            ground_truths = []

            bar = create_progressbar(prefix='Test epoch | ',
                                     widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
            log_progress('Test epoch...')

            step_count = 0

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # First pass, compute losses and transposed logits for decoding
            while True:
                try:
                    batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
                        session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
                except tf.errors.OutOfRangeError:
                    break

                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes, scorer=scorer,
                                                        cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
                predictions.extend(d[0][1] for d in decoded)
                ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
                wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
                losses.extend(batch_loss)

                step_count += 1
                bar.update(step_count)

            bar.finish()

            # Print test summary
            test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
            return test_samples

        samples = []
        for csv, init_op in zip(test_csvs, test_init_ops):
            print('Testing model on {}'.format(csv))
            samples.extend(run_test(init_op, dataset=csv))
        return samples


def main(_):
    initialize_globals()

    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        sys.exit(1)

    from .train import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
    samples = evaluate(FLAGS.test_files.split(','), create_model)

    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats
        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


def run_script():
    create_flags()
    absl.app.run(main)

if __name__ == '__main__':
    run_script()
942
training/deepspeech_training/train.py
Normal file
@ -0,0 +1,942 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import os
import sys

LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL

import absl.app
import json
import numpy as np
import progressbar
import shutil
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import time

tfv1.logging.set_verbosity({
    '0': tfv1.logging.DEBUG,
    '1': tfv1.logging.INFO,
    '2': tfv1.logging.WARN,
    '3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))

from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
from .util.logging import log_info, log_error, log_warn, log_debug, log_progress, create_progressbar

check_ctcdecoder_version()

# Graph Creation
# ==============

def variable_on_cpu(name, shape, initializer):
    r"""
    Next we concern ourselves with graph creation.
    However, before we do so we must introduce a utility function ``variable_on_cpu()``
    used to create a variable in CPU memory.
    """
    # Use the /cpu:0 device for scoped operations
    with tf.device(Config.cpu_device):
        # Create or fetch the appropriate variable
        var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
    return var


def create_overlapping_windows(batch_x):
    batch_size = tf.shape(input=batch_x)[0]
    window_width = 2 * Config.n_context + 1
    num_channels = Config.n_input

    # Create a constant convolution filter using an identity matrix, so that the
    # convolution returns patches of the input tensor as is, and we can create
    # overlapping windows over the MFCCs.
    eye_filter = tf.constant(np.eye(window_width * num_channels)
                               .reshape(window_width, num_channels, window_width * num_channels), tf.float32)  # pylint: disable=bad-continuation

    # Create overlapping windows
    batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')

    # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
    batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])

    return batch_x
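# Illustrative sketch (added for exposition, not part of the original diff):
# the identity-filter conv1d above is equivalent to stacking, for each time
# step, the window of 2*n_context+1 frames centered on it, zero-padded at the
# edges. A NumPy rendering of the same idea, with all sizes assumed:
def _overlapping_windows_numpy_sketch(x, n_context):
    # x: float32 array of shape [n_steps, n_input]
    window_width = 2 * n_context + 1
    padded = np.pad(x, ((n_context, n_context), (0, 0)))  # zero-pad like 'SAME'
    # windows[t] holds the window_width frames centered on step t
    return np.stack([padded[t:t + window_width] for t in range(x.shape[0])])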
def dense(name, x, units, dropout_rate=None, relu=True):
    with tfv1.variable_scope(name):
        bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
        weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))

        output = tf.nn.bias_add(tf.matmul(x, weights), bias)

        if relu:
            output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)

        if dropout_rate is not None:
            output = tf.nn.dropout(output, rate=dropout_rate)

        return output


def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
                                                    forget_bias=0,
                                                    reuse=reuse,
                                                    name='cudnn_compatible_lstm_cell')

        output, output_state = fw_cell(inputs=x,
                                       dtype=tf.float32,
                                       sequence_length=seq_length,
                                       initial_state=previous_state)

    return output, output_state


def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
    assert previous_state is None  # 'Passing previous state not supported with CuDNN backend'

    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
    # the object it creates the variables, and then you just call it several times
    # to enable variable re-use. Because all of our code is structured in an old
    # school TensorFlow style where you can just call tf.get_variable again with
    # reuse=True to reuse variables, we can't easily make use of the object oriented
    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
    # emulating a static function variable.
    if not rnn_impl_cudnn_rnn.cell:
        # Forward direction cell:
        fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
                                                 num_units=Config.n_cell_dim,
                                                 input_mode='linear_input',
                                                 direction='unidirectional',
                                                 dtype=tf.float32)
        rnn_impl_cudnn_rnn.cell = fw_cell

    output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
                                                   sequence_lengths=seq_length)

    return output, output_state

rnn_impl_cudnn_rnn.cell = None
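# Illustrative sketch (added for exposition): the function-attribute pattern
# used above emulates a C-style static local, caching a singleton across calls.
def _singleton_on_function_sketch():
    if _singleton_on_function_sketch.instance is None:
        _singleton_on_function_sketch.instance = object()  # stand-in for CudnnLSTM
    return _singleton_on_function_sketch.instance
_singleton_on_function_sketch.instance = None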
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
    with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
        # Forward direction cell:
        fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
                                            forget_bias=0,
                                            reuse=reuse,
                                            name='cudnn_compatible_lstm_cell')

        # Split rank N tensor into list of rank N-1 tensors
        x = [x[l] for l in range(x.shape[0])]

        output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
                                                  inputs=x,
                                                  sequence_length=seq_length,
                                                  initial_state=previous_state,
                                                  dtype=tf.float32,
                                                  scope='cell_0')

        output = tf.concat(output, 0)

    return output, output_state


def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
    layers = {}

    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
    if not batch_size:
        batch_size = tf.shape(input=batch_x)[0]

    # Create overlapping feature windows if needed
    if overlap:
        batch_x = create_overlapping_windows(batch_x)

    # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
    # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.

    # Permute n_steps and batch_size
    batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
    # Reshape to prepare input for first layer
    batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context])  # (n_steps*batch_size, n_input + 2*n_input*n_context)
    layers['input_reshaped'] = batch_x

    # The next three blocks will pass `batch_x` through three hidden layers with
    # clipped RELU activation and dropout.
    layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
    layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
    layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])

    # `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
    # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
    layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])

    # Run through parametrized RNN implementation, as we use different RNNs
    # for training and inference
    output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)

    # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
    # to a tensor of shape [n_steps*batch_size, n_cell_dim]
    output = tf.reshape(output, [-1, Config.n_cell_dim])
    layers['rnn_output'] = output
    layers['rnn_output_state'] = output_state

    # Now we feed `output` to the fifth hidden layer with clipped RELU activation
    layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])

    # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
    layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)

    # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
    # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
    # Note, that this differs from the input in that it is time-major.
    layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
    layers['raw_logits'] = layer_6

    # Output shape: [n_steps, batch_size, n_hidden_6]
    return layer_6, layers
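# Shape walk-through for create_model (illustrative, T = n_steps, B = batch_size):
#   input                [B, T, n_input]
#   overlapping windows  [B, T, 2*n_context+1, n_input]
#   time-major + flatten [T*B, n_input*(2*n_context+1)]
#   dense layers 1-3     [T*B, n_hidden_3]  -> reshaped to [T, B, n_hidden_3]
#   RNN                  [T, B, n_cell_dim] -> flattened to [T*B, n_cell_dim]
#   dense 5 and 6        [T*B, n_hidden_6]  -> logits [T, B, n_hidden_6] (time-major)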
# Accuracy and Loss
# =================

# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# the loss function used by our network should be the CTC loss function
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.

def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
    r'''
    This routine computes the CTC loss for a mini-batch. It returns the
    average loss across the batch, together with the names of any files in
    the batch that produced a non-finite loss.
    '''
    # Obtain the next batch of data
    batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()

    if FLAGS.train_cudnn:
        rnn_impl = rnn_impl_cudnn_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    # Calculate the logits of the batch
    logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

    # Compute the CTC loss using TensorFlow's `ctc_loss`
    total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

    # Check if any files lead to non finite loss
    non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))

    # Calculate the average loss across the batch
    avg_loss = tf.reduce_mean(input_tensor=total_loss)

    # Finally we return the average loss
    return avg_loss, non_finite_files
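# For reference (illustrative): the tensors consumed by tfv1.nn.ctc_loss above.
#   labels          SparseTensor of target character indices, one row per sample
#   inputs          pre-softmax logits, time-major float32 [max_time, batch, n_classes]
#   sequence_length int32 [batch], number of valid frames per sample
# The op returns one scalar loss per batch element, hence the reduce_mean.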
# Adam Optimization
# =================

# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
# because, generally, it requires less fine-tuning.
def create_optimizer(learning_rate_var):
    optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
                                         beta1=FLAGS.beta1,
                                         beta2=FLAGS.beta2,
                                         epsilon=FLAGS.epsilon)
    return optimizer


# Towers
# ======

# In order to properly make use of multiple GPUs, one must introduce new abstractions,
# not present when using a single GPU, that facilitate the multi-GPU use case.
# In particular, one must introduce a means to isolate the inference and gradient
# calculations on the various GPUs.
# The abstraction we introduce for this purpose is called a 'tower'.
# A tower is specified by two properties:
# * **Scope** - A scope, as provided by `tf.name_scope()`,
#   is a means to isolate the operations within a tower.
#   For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
# * **Device** - A hardware device, as provided by `tf.device()`,
#   on which all operations within the tower execute.
#   For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
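# Illustrative sketch (added for exposition) of the scope/device pairing just
# described; no ops are created here, so it is harmless on CPU-only machines:
def _tower_scope_sketch():
    with tf.device('/gpu:0'):
        with tf.name_scope('tower_0'):
            # ops created here would run on GPU 0 and be named 'tower_0/...'
            pass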
def get_tower_results(iterator, optimizer, dropout_rates):
    r'''
    With this preliminary step out of the way, we can for each GPU introduce a
    tower for whose batch we calculate and return the optimization gradients
    and the average loss across towers.
    '''
    # To calculate the mean of the losses
    tower_avg_losses = []

    # Tower gradients to return
    tower_gradients = []

    # Aggregate any non finite files in the batches
    tower_non_finite_files = []

    with tfv1.variable_scope(tfv1.get_variable_scope()):
        # Loop over available_devices
        for i in range(len(Config.available_devices)):
            # Execute operations of tower i on device i
            device = Config.available_devices[i]
            with tf.device(device):
                # Create a scope for all operations of tower i
                with tf.name_scope('tower_%d' % i):
                    # Calculate the average loss for this tower's batch and
                    # collect any files that produce a non-finite loss
                    avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

                    # Allow for variables to be re-used by the next tower
                    tfv1.get_variable_scope().reuse_variables()

                    # Retain tower's avg losses
                    tower_avg_losses.append(avg_loss)

                    # Compute gradients for model parameters using tower's mini-batch
                    gradients = optimizer.compute_gradients(avg_loss)

                    # Retain tower's gradients
                    tower_gradients.append(gradients)

                    tower_non_finite_files.append(non_finite_files)

    avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
    tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

    all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)

    # Return gradients and the average loss
    return tower_gradients, avg_loss_across_towers, all_non_finite_files
def average_gradients(tower_gradients):
    r'''
    A routine for computing each variable's average of the gradients obtained from the GPUs.
    Note also that this code acts as a synchronization point as it requires all
    GPUs to be finished with their mini-batch before it can run to completion.
    '''
    # List of average gradients to return to the caller
    average_grads = []

    # Run this on cpu_device to conserve GPU memory
    with tf.device(Config.cpu_device):
        # Loop over gradient/variable pairs from all towers
        for grad_and_vars in zip(*tower_gradients):
            # Introduce grads to store the gradients for the current variable
            grads = []

            # Loop over the gradients for the current variable
            for g, _ in grad_and_vars:
                # Add 0 dimension to the gradients to represent the tower.
                expanded_g = tf.expand_dims(g, 0)
                # Append on a 'tower' dimension which we will average over below.
                grads.append(expanded_g)

            # Average over the 'tower' dimension
            grad = tf.concat(grads, 0)
            grad = tf.reduce_mean(input_tensor=grad, axis=0)

            # Create a gradient/variable tuple for the current variable with its average gradient
            grad_and_var = (grad, grad_and_vars[0][1])

            # Add the current tuple to average_grads
            average_grads.append(grad_and_var)

    # Return result to caller
    return average_grads
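# Illustrative sketch (added for exposition): the same averaging in NumPy,
# with toy gradients assumed from two towers.
def _average_gradients_numpy_sketch():
    g_tower0 = np.array([1.0, 3.0])
    g_tower1 = np.array([2.0, 5.0])
    stacked = np.stack([g_tower0, g_tower1])  # adds the 'tower' dimension
    return stacked.mean(axis=0)               # -> array([1.5, 4.0])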
# Logging
# =======

def log_variable(variable, gradient=None):
    r'''
    We introduce a function for logging a tensor variable's current state.
    It logs scalar values for the mean, standard deviation, minimum and maximum.
    Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
    '''
    name = variable.name.replace(':', '_')
    mean = tf.reduce_mean(input_tensor=variable)
    tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
    tfv1.summary.scalar(name='%s/stddev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
    tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
    tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
    tfv1.summary.histogram(name=name, values=variable)
    if gradient is not None:
        if isinstance(gradient, tf.IndexedSlices):
            grad_values = gradient.values
        else:
            grad_values = gradient
        if grad_values is not None:
            tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)


def log_grads_and_vars(grads_and_vars):
    r'''
    Let's also introduce a helper function for logging collections of gradient/variable tuples.
    '''
    for gradient, variable in grads_and_vars:
        log_variable(variable, gradient=gradient)
def train():
    do_cache_dataset = True

    # pylint: disable=too-many-boolean-expressions
    if (FLAGS.data_aug_features_multiplicative > 0 or
            FLAGS.data_aug_features_additive > 0 or
            FLAGS.augmentation_spec_dropout_keeprate < 1 or
            FLAGS.augmentation_freq_and_time_masking or
            FLAGS.augmentation_pitch_and_tempo_scaling or
            FLAGS.augmentation_speed_up_std > 0 or
            FLAGS.augmentation_sparse_warp):
        do_cache_dataset = False

    exception_box = ExceptionBox()

    # Create training and validation datasets
    train_set = create_dataset(FLAGS.train_files.split(','),
                               batch_size=FLAGS.train_batch_size,
                               enable_cache=FLAGS.feature_cache and do_cache_dataset,
                               cache_path=FLAGS.feature_cache,
                               train_phase=True,
                               exception_box=exception_box,
                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
                               buffering=FLAGS.read_buffer)

    iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                 tfv1.data.get_output_shapes(train_set),
                                                 output_classes=tfv1.data.get_output_classes(train_set))

    # Make initialization ops for switching between the two sets
    train_init_op = iterator.make_initializer(train_set)

    if FLAGS.dev_files:
        dev_sources = FLAGS.dev_files.split(',')
        dev_sets = [create_dataset([source],
                                   batch_size=FLAGS.dev_batch_size,
                                   train_phase=False,
                                   exception_box=exception_box,
                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
                                   buffering=FLAGS.read_buffer) for source in dev_sources]
        dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

    # Dropout
    dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
    dropout_feed_dict = {
        dropout_rates[0]: FLAGS.dropout_rate,
        dropout_rates[1]: FLAGS.dropout_rate2,
        dropout_rates[2]: FLAGS.dropout_rate3,
        dropout_rates[3]: FLAGS.dropout_rate4,
        dropout_rates[4]: FLAGS.dropout_rate5,
        dropout_rates[5]: FLAGS.dropout_rate6,
    }
    no_dropout_feed_dict = {
        rate: 0. for rate in dropout_rates
    }

    # Building the graph
    learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
    reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
    optimizer = create_optimizer(learning_rate_var)

    # Enable mixed precision training
    if FLAGS.automatic_mixed_precision:
        log_info('Enabling automatic mixed precision training.')
        optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)

    # Average tower gradients across GPUs
    avg_tower_gradients = average_gradients(gradients)
    log_grads_and_vars(avg_tower_gradients)

    # global_step is automagically incremented by the optimizer
    global_step = tfv1.train.get_or_create_global_step()
    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

    # Summaries
    step_summaries_op = tfv1.summary.merge_all('step_summaries')
    step_summary_writers = {
        'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
        'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
    }

    # Checkpointing
    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

    # Save flags next to checkpoints
    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
    with open(flags_file, 'w') as fout:
        fout.write(FLAGS.flags_into_string())

    with tfv1.Session(config=Config.session_config) as session:
        log_debug('Session opened.')

        # Prevent further graph changes
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        load_or_init_graph_for_training(session)

        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
            train_op = apply_gradient_op if is_train else []
            feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict

            total_loss = 0.0
            step_count = 0

            step_summary_writer = step_summary_writers.get(set_name)
            checkpoint_time = time.time()

            # Setup progress bar
            class LossWidget(progressbar.widgets.FormatLabel):
                def __init__(self):
                    progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

                def __call__(self, progress, data, **kwargs):
                    data['mean_loss'] = total_loss / step_count if step_count else 0.0
                    return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

            prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
            widgets = [' | ', progressbar.widgets.Timer(),
                       ' | Steps: ', progressbar.widgets.Counter(),
                       ' | ', LossWidget()]
            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()

            # Initialize iterator to the appropriate dataset
            session.run(init_op)

            # Batch loop
            while True:
                try:
                    _, current_step, batch_loss, problem_files, step_summary = \
                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
                                    feed_dict=feed_dict)
                    exception_box.raise_if_set()
                except tf.errors.InvalidArgumentError as err:
                    if FLAGS.augmentation_sparse_warp:
                        log_info("Ignoring sparse warp error: {}".format(err))
                        continue
                    raise
                except tf.errors.OutOfRangeError:
                    exception_box.raise_if_set()
                    break

                if problem_files.size > 0:
                    problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
                    log_error('The following files caused an infinite (or NaN) '
                              'loss: {}'.format(','.join(problem_files)))

                total_loss += batch_loss
                step_count += 1

                pbar.update(step_count)

                step_summary_writer.add_summary(step_summary, current_step)

                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
                    checkpoint_time = time.time()

            pbar.finish()
            mean_loss = total_loss / step_count if step_count > 0 else 0.0
            return mean_loss, step_count
        log_info('STARTING Optimization')
        train_start_time = datetime.utcnow()
        best_dev_loss = float('inf')
        dev_losses = []
        epochs_without_improvement = 0
        try:
            for epoch in range(FLAGS.epochs):
                # Training
                log_progress('Training epoch %d...' % epoch)
                train_loss, _ = run_set('train', epoch, train_init_op)
                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)

                if FLAGS.dev_files:
                    # Validation
                    dev_loss = 0.0
                    total_steps = 0
                    for source, init_op in zip(dev_sources, dev_init_ops):
                        log_progress('Validating epoch %d on %s...' % (epoch, source))
                        set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                        dev_loss += set_loss * steps
                        total_steps += steps
                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))

                    dev_loss = dev_loss / total_steps
                    dev_losses.append(dev_loss)

                    # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
                    # the improvement has to be greater than FLAGS.es_min_delta
                    if dev_loss > best_dev_loss - FLAGS.es_min_delta:
                        epochs_without_improvement += 1
                    else:
                        epochs_without_improvement = 0

                    # Save new best model
                    if dev_loss < best_dev_loss:
                        best_dev_loss = dev_loss
                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))

                    # Early stopping
                    if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
                            epochs_without_improvement))
                        break

                    # Reduce learning rate on plateau
                    if (FLAGS.reduce_lr_on_plateau and
                            epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
                        # If the learning rate was reduced and there is still no improvement
                        # wait FLAGS.plateau_epochs before the learning rate is reduced again
                        session.run(reduce_learning_rate_op)
                        current_learning_rate = learning_rate_var.eval()
                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
                            current_learning_rate))

        except KeyboardInterrupt:
            pass
        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
    log_debug('Session closed.')
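# Worked example of the plateau schedule in train() above (illustrative,
# assuming --learning_rate 0.0001, --plateau_reduction 0.1, --plateau_epochs 10):
# after 10 epochs without a dev-loss improvement the rate drops to 1e-5, after
# 10 more it drops to 1e-6, and so on; any improvement resets the counter.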
def test():
    samples = evaluate(FLAGS.test_files.split(','), create_model)
    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats
        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    batch_size = batch_size if batch_size > 0 else None

    # Create feature computation graph
    input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
    samples = tf.expand_dims(input_samples, -1)
    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
    mfccs = tf.identity(mfccs, name='mfccs')

    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
    # This shape is read by the native_client in DS_CreateModel to know the
    # value of n_steps, n_context and n_input. Make sure you update the code
    # there if this shape is changed.
    input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
    seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')

    # A non-positive batch_size was mapped to None above; comparing None with
    # `<= 0` would raise a TypeError, so test for None instead
    if batch_size is None:
        # no state management since n_step is expected to be dynamic too (see below)
        previous_state = None
    else:
        previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
        previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

        previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)

    # One rate per layer
    no_dropout = [None] * 6

    if tflite:
        rnn_impl = rnn_impl_static_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    logits, layers = create_model(batch_x=input_tensor,
                                  batch_size=batch_size,
                                  seq_length=seq_length if not FLAGS.export_tflite else None,
                                  dropout=no_dropout,
                                  previous_state=previous_state,
                                  overlap=False,
                                  rnn_impl=rnn_impl)

    # TF Lite runtime will check that input dimensions are 1, 2 or 4
    # by default we get 3, the middle one being batch_size which is forced to
    # one on inference graph, so remove that dimension
    if tflite:
        logits = tf.squeeze(logits, [1])

    # Apply softmax for CTC decoder
    logits = tf.nn.softmax(logits, name='logits')

    if batch_size is None:
        if tflite:
            raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
        if n_steps > 0:
            raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
        return (
            {
                'input': input_tensor,
                'input_lengths': seq_length,
            },
            {
                'outputs': logits,
            },
            layers
        )

    new_state_c, new_state_h = layers['rnn_output_state']
    new_state_c = tf.identity(new_state_c, name='new_state_c')
    new_state_h = tf.identity(new_state_h, name='new_state_h')

    inputs = {
        'input': input_tensor,
        'previous_state_c': previous_state_c,
        'previous_state_h': previous_state_h,
        'input_samples': input_samples,
    }

    if not FLAGS.export_tflite:
        inputs['input_lengths'] = seq_length

    outputs = {
        'outputs': logits,
        'new_state_c': new_state_c,
        'new_state_h': new_state_h,
        'mfccs': mfccs,
    }

    return inputs, outputs, layers
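# Hedged usage sketch (added for exposition): streaming inference threads the
# LSTM state through the new_state_c/new_state_h outputs returned above.
# `feature_chunks` is a hypothetical iterable of (chunk, chunk_len) pairs.
#
#   state_c = np.zeros([1, Config.n_cell_dim])
#   state_h = np.zeros([1, Config.n_cell_dim])
#   for chunk, chunk_len in feature_chunks:
#       logits, state_c, state_h = session.run(
#           [outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']],
#           feed_dict={inputs['input']: chunk,
#                      inputs['input_lengths']: [chunk_len],
#                      inputs['previous_state_c']: state_c,
#                      inputs['previous_state_h']: state_h})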
def file_relative_read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()
def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')

    inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)

    graph_version = int(file_relative_read('GRAPH_VERSION').strip())
    assert graph_version > 0

    outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
    outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
    output_names = output_names_tensors + output_names_ops

    with tf.Session() as session:
        # Restore variables from checkpoint
        load_graph_for_evaluation(session)

        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)

        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

        if not os.path.isdir(FLAGS.export_dir):
            os.makedirs(FLAGS.export_dir)

        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names)

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph,
            dest_nodes=output_names)

        if not FLAGS.export_tflite:
            with open(output_graph_path, 'wb') as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

            converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open(output_tflite_path, 'wb') as fout:
                fout.write(tflite_model)

        log_info('Models exported at %s' % (FLAGS.export_dir))

    metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
        FLAGS.export_author_id,
        FLAGS.export_model_name,
        FLAGS.export_model_version))

    model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
    with open(metadata_fname, 'w') as f:
        f.write('---\n')
        f.write('author: {}\n'.format(FLAGS.export_author_id))
        f.write('model_name: {}\n'.format(FLAGS.export_model_name))
        f.write('model_version: {}\n'.format(FLAGS.export_model_version))
        f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
        f.write('license: {}\n'.format(FLAGS.export_license))
        f.write('language: {}\n'.format(FLAGS.export_language))
        f.write('runtime: {}\n'.format(model_runtime))
        f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
        f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
        f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
        f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
        f.write('---\n')
        f.write('{}\n'.format(FLAGS.export_description))

    log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
def package_zip():
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '')  # Force ending '/'
    zip_filename = os.path.dirname(export_dir)

    shutil.copy(FLAGS.scorer_path, export_dir)

    archive = shutil.make_archive(zip_filename, 'zip', export_dir)
    log_info('Exported packaged model {}'.format(archive))
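# Worked example of the path handling above (assumed path): with
# --export_dir /tmp/export/en/
#   export_dir   -> '/tmp/export/en/'  (trailing slash forced)
#   zip_filename -> '/tmp/export/en'   (dirname of the slash-terminated path)
#   archive      -> '/tmp/export/en.zip'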
def do_single_file_inference(input_file_path):
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        load_graph_for_evaluation(session)

        features, features_len = audiofile_to_features(input_file_path)
        previous_state_c = np.zeros([1, Config.n_cell_dim])
        previous_state_h = np.zeros([1, Config.n_cell_dim])

        # Add batch dimension
        features = tf.expand_dims(features, 0)
        features_len = tf.expand_dims(features_len, 0)

        # Evaluate
        features = create_overlapping_windows(features).eval(session=session)
        features_len = features_len.eval(session=session)

        logits = outputs['outputs'].eval(feed_dict={
            inputs['input']: features,
            inputs['input_lengths']: features_len,
            inputs['previous_state_c']: previous_state_c,
            inputs['previous_state_h']: previous_state_h,
        }, session=session)

        logits = np.squeeze(logits)

        if FLAGS.scorer_path:
            scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                            FLAGS.scorer_path, Config.alphabet)
        else:
            scorer = None
        decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
                                          scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
                                          cutoff_top_n=FLAGS.cutoff_top_n)
        # Print highest probability result
        print(decoded[0][1])
def early_training_checks():
    # Check for proper scorer early
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.scorer_path, Config.alphabet)
        del scorer

    if FLAGS.train_files and FLAGS.test_files and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir:
        log_warn('WARNING: You specified different values for --load_checkpoint_dir '
                 'and --save_checkpoint_dir, but you are running training and testing '
                 'in a single invocation. The testing step will respect --load_checkpoint_dir, '
                 'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. '
                 'Train and test in two separate invocations, specifying the correct '
                 '--load_checkpoint_dir in both cases, or use the same location '
                 'for loading and saving.')
def main(_):
    initialize_globals()
    early_training_checks()

    if FLAGS.train_files:
        tfv1.reset_default_graph()
        tfv1.set_random_seed(FLAGS.random_seed)
        train()

    if FLAGS.test_files:
        tfv1.reset_default_graph()
        test()

    if FLAGS.export_dir and not FLAGS.export_zip:
        tfv1.reset_default_graph()
        export()

    if FLAGS.export_zip:
        tfv1.reset_default_graph()
        FLAGS.export_tflite = True

        if os.listdir(FLAGS.export_dir):
            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
            sys.exit(1)

        export()
        package_zip()

    if FLAGS.one_shot_infer:
        tfv1.reset_default_graph()
        do_single_file_inference(FLAGS.one_shot_infer)


def run_script():
    create_flags()
    absl.app.run(main)


if __name__ == '__main__':
    run_script()
0  training/deepspeech_training/util/__init__.py  Normal file
@@ -5,7 +5,7 @@ import tempfile
 import collections
 import numpy as np

-from util.helpers import LimitingPool
+from .helpers import LimitingPool

 DEFAULT_RATE = 16000
 DEFAULT_CHANNELS = 1
@@ -2,11 +2,11 @@ import sys
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1

-from util.flags import FLAGS
-from util.logging import log_info, log_error, log_warn
+from .flags import FLAGS
+from .logging import log_info, log_error, log_warn


-def _load_checkpoint(session, checkpoint_path):
+def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
     # Load the checkpoint and put all variables into loading list
     # we will exclude variables we do not wish to load and then
     # we will initialize them instead
@@ -45,7 +45,7 @@ def _load_checkpoint(session, checkpoint_path):
                   'tensors. Missing variables: {}'.format(missing_var_names))
         sys.exit(1)

-    if FLAGS.drop_source_layers > 0:
+    if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
         # Say we want to exclude all layers except for the first one,
@@ -87,20 +87,14 @@ def _initialize_all_variables(session):
         session.run(v.initializer)


-def load_or_init_graph(session, method_order):
-    '''
-    Load variables from checkpoint or initialize variables following the method
-    order specified in the method_order parameter.
-
-    Valid methods are 'best', 'last' and 'init'.
-    '''
+def _load_or_init_impl(session, method_order, allow_drop_layers):
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == 'best':
             ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
             if ckpt_path:
                 log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
             log_info('Could not find best validating checkpoint.')

         # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
@@ -108,7 +102,7 @@ def load_or_init_graph(session, method_order):
             ckpt_path = _checkpoint_path_or_none('checkpoint')
             if ckpt_path:
                 log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
-                return _load_checkpoint(session, ckpt_path)
+                return _load_checkpoint(session, ckpt_path, allow_drop_layers)
             log_info('Could not find most recent checkpoint.')

         # Initialize all variables
@@ -122,3 +116,31 @@ def load_or_init_graph(session, method_order):

     log_error('All initialization methods failed ({}).'.format(method_order))
     sys.exit(1)
+
+
+def load_or_init_graph_for_training(session):
+    '''
+    Load variables from checkpoint or initialize variables. By default this will
+    try to load the best validating checkpoint, then try the last checkpoint,
+    and finally initialize the weights from scratch. This can be overridden with
+    the `--load_train` flag. See its documentation for more info.
+    '''
+    if FLAGS.load_train == 'auto':
+        methods = ['best', 'last', 'init']
+    else:
+        methods = [FLAGS.load_train]
+    _load_or_init_impl(session, methods, allow_drop_layers=True)
+
+
+def load_graph_for_evaluation(session):
+    '''
+    Load variables from checkpoint. Initialization is not allowed. By default
+    this will try to load the best validating checkpoint, then try the last
+    checkpoint. This can be overridden with the `--load_evaluate` flag. See its
+    documentation for more info.
+    '''
+    if FLAGS.load_evaluate == 'auto':
+        methods = ['best', 'last']
+    else:
+        methods = [FLAGS.load_evaluate]
+    _load_or_init_impl(session, methods, allow_drop_layers=False)
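Note (added for exposition): with the default `--load_train auto`, training tries checkpoints in the order best → last → init (fresh initialization), while `--load_evaluate auto` only tries best → last, so evaluation exits with an error instead of silently running on freshly initialized weights.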
@@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
 from attrdict import AttrDict
 from xdg import BaseDirectory as xdg

-from util.flags import FLAGS
-from util.gpu import get_available_gpus
-from util.logging import log_error
-from util.text import Alphabet, UTF8Alphabet
-from util.helpers import parse_file_size
+from .flags import FLAGS
+from .gpu import get_available_gpus
+from .logging import log_error, log_warn
+from .text import Alphabet, UTF8Alphabet
+from .helpers import parse_file_size

 class ConfigSingleton:
     _config = None
@@ -45,8 +45,11 @@ def initialize_globals():
     if not FLAGS.checkpoint_dir:
         FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

-    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
-        FLAGS.load = 'auto'
+    if FLAGS.load_train not in ['last', 'best', 'init', 'auto']:
+        FLAGS.load_train = 'auto'
+
+    if FLAGS.load_evaluate not in ['last', 'best', 'auto']:
+        FLAGS.load_evaluate = 'auto'

     # Set default summary dir
     if not FLAGS.summary_dir:
@@ -7,8 +7,8 @@ import numpy as np

 from attrdict import AttrDict

-from util.flags import FLAGS
-from util.text import levenshtein
+from .flags import FLAGS
+from .text import levenshtein


 def pmap(fun, iterable):