Package training code to avoid sys.path hacks
This commit is contained in:
parent
58bc2f2bb1
commit
a05baa35c9
@ -9,7 +9,7 @@ python:
|
||||
|
||||
jobs:
|
||||
include:
|
||||
- stage: cardboard linter
|
||||
- name: cardboard linter
|
||||
install:
|
||||
- pip install --upgrade cardboardlint pylint
|
||||
script:
|
||||
@ -17,9 +17,10 @@ jobs:
|
||||
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
||||
cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
|
||||
fi
|
||||
- stage: python unit tests
|
||||
- name: python unit tests
|
||||
install:
|
||||
- pip install --upgrade -r requirements_tests.txt
|
||||
- pip install --upgrade -r requirements_tests.txt;
|
||||
pip install --upgrade .
|
||||
script:
|
||||
- if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
|
||||
python -m unittest;
|
||||
|
937
DeepSpeech.py
937
DeepSpeech.py
@ -2,934 +2,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||
|
||||
import absl.app
|
||||
import json
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import time
|
||||
|
||||
tfv1.logging.set_verbosity({
|
||||
'0': tfv1.logging.DEBUG,
|
||||
'1': tfv1.logging.INFO,
|
||||
'2': tfv1.logging.WARN,
|
||||
'3': tfv1.logging.ERROR
|
||||
}.get(DESIRED_LOG_LEVEL))
|
||||
|
||||
from datetime import datetime
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from util.config import Config, initialize_globals
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
# Graph Creation
|
||||
# ==============
|
||||
|
||||
def variable_on_cpu(name, shape, initializer):
|
||||
r"""
|
||||
Next we concern ourselves with graph creation.
|
||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||
used to create a variable in CPU memory.
|
||||
"""
|
||||
# Use the /cpu:0 device for scoped operations
|
||||
with tf.device(Config.cpu_device):
|
||||
# Create or get apropos variable
|
||||
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||
return var
|
||||
|
||||
|
||||
def create_overlapping_windows(batch_x):
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
window_width = 2 * Config.n_context + 1
|
||||
num_channels = Config.n_input
|
||||
|
||||
# Create a constant convolution filter using an identity matrix, so that the
|
||||
# convolution returns patches of the input tensor as is, and we can create
|
||||
# overlapping windows over the MFCCs.
|
||||
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
||||
|
||||
# Create overlapping windows
|
||||
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
||||
|
||||
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||
|
||||
return batch_x
|
||||
|
||||
|
||||
def dense(name, x, units, dropout_rate=None, relu=True):
|
||||
with tfv1.variable_scope(name):
|
||||
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||
|
||||
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||
|
||||
if relu:
|
||||
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
||||
|
||||
if dropout_rate is not None:
|
||||
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
output, output_state = fw_cell(inputs=x,
|
||||
dtype=tf.float32,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
||||
|
||||
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||
# the object it creates the variables, and then you just call it several times
|
||||
# to enable variable re-use. Because all of our code is structure in an old
|
||||
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||
# emulating a static function variable.
|
||||
if not rnn_impl_cudnn_rnn.cell:
|
||||
# Forward direction cell:
|
||||
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
||||
num_units=Config.n_cell_dim,
|
||||
input_mode='linear_input',
|
||||
direction='unidirectional',
|
||||
dtype=tf.float32)
|
||||
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||
|
||||
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
||||
sequence_lengths=seq_length)
|
||||
|
||||
return output, output_state
|
||||
|
||||
rnn_impl_cudnn_rnn.cell = None
|
||||
|
||||
|
||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
||||
# Forward direction cell:
|
||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
# Split rank N tensor into list of rank N-1 tensors
|
||||
x = [x[l] for l in range(x.shape[0])]
|
||||
|
||||
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
|
||||
inputs=x,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state,
|
||||
dtype=tf.float32,
|
||||
scope='cell_0')
|
||||
|
||||
output = tf.concat(output, 0)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
|
||||
layers = {}
|
||||
|
||||
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
if not batch_size:
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
|
||||
# Create overlapping feature windows if needed
|
||||
if overlap:
|
||||
batch_x = create_overlapping_windows(batch_x)
|
||||
|
||||
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
||||
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
||||
|
||||
# Permute n_steps and batch_size
|
||||
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
||||
# Reshape to prepare input for first layer
|
||||
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
||||
layers['input_reshaped'] = batch_x
|
||||
|
||||
# The next three blocks will pass `batch_x` through three hidden layers with
|
||||
# clipped RELU activation and dropout.
|
||||
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
|
||||
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
|
||||
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
|
||||
|
||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
||||
|
||||
# Run through parametrized RNN implementation, as we use different RNNs
|
||||
# for training and inference
|
||||
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
||||
|
||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
||||
layers['rnn_output'] = output
|
||||
layers['rnn_output_state'] = output_state
|
||||
|
||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
||||
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
|
||||
|
||||
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
||||
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
|
||||
|
||||
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
||||
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
||||
# Note, that this differs from the input in that it is time-major.
|
||||
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
|
||||
layers['raw_logits'] = layer_6
|
||||
|
||||
# Output shape: [n_steps, batch_size, n_hidden_6]
|
||||
return layer_6, layers
|
||||
|
||||
|
||||
# Accuracy and Loss
|
||||
# =================
|
||||
|
||||
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# the loss function used by our network should be the CTC loss function
|
||||
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
|
||||
# Conveniently, this loss function is implemented in TensorFlow.
|
||||
# Thus, we can simply make use of this implementation to define our loss.
|
||||
|
||||
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
r'''
|
||||
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
|
||||
Next to total and average loss it returns the mean edit distance,
|
||||
the decoded result and the batch's original Y.
|
||||
'''
|
||||
# Obtain the next batch of data
|
||||
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
|
||||
if FLAGS.train_cudnn:
|
||||
rnn_impl = rnn_impl_cudnn_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
# Calculate the logits of the batch
|
||||
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
|
||||
|
||||
# Compute the CTC loss using TensorFlow's `ctc_loss`
|
||||
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
|
||||
|
||||
# Check if any files lead to non finite loss
|
||||
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
|
||||
|
||||
# Calculate the average loss across the batch
|
||||
avg_loss = tf.reduce_mean(input_tensor=total_loss)
|
||||
|
||||
# Finally we return the average loss
|
||||
return avg_loss, non_finite_files
|
||||
|
||||
|
||||
# Adam Optimization
|
||||
# =================
|
||||
|
||||
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# in which 'Nesterov's Accelerated Gradient Descent'
|
||||
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
|
||||
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
|
||||
# because, generally, it requires less fine-tuning.
|
||||
def create_optimizer(learning_rate_var):
|
||||
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||
beta1=FLAGS.beta1,
|
||||
beta2=FLAGS.beta2,
|
||||
epsilon=FLAGS.epsilon)
|
||||
return optimizer
|
||||
|
||||
|
||||
# Towers
|
||||
# ======
|
||||
|
||||
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
|
||||
# not present when using a single GPU, that facilitate the multi-GPU use case.
|
||||
# In particular, one must introduce a means to isolate the inference and gradient
|
||||
# calculations on the various GPU's.
|
||||
# The abstraction we intoduce for this purpose is called a 'tower'.
|
||||
# A tower is specified by two properties:
|
||||
# * **Scope** - A scope, as provided by `tf.name_scope()`,
|
||||
# is a means to isolate the operations within a tower.
|
||||
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
|
||||
# * **Device** - A hardware device, as provided by `tf.device()`,
|
||||
# on which all operations within the tower execute.
|
||||
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
|
||||
|
||||
def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
r'''
|
||||
With this preliminary step out of the way, we can for each GPU introduce a
|
||||
tower for which's batch we calculate and return the optimization gradients
|
||||
and the average loss across towers.
|
||||
'''
|
||||
# To calculate the mean of the losses
|
||||
tower_avg_losses = []
|
||||
|
||||
# Tower gradients to return
|
||||
tower_gradients = []
|
||||
|
||||
# Aggregate any non finite files in the batches
|
||||
tower_non_finite_files = []
|
||||
|
||||
with tfv1.variable_scope(tfv1.get_variable_scope()):
|
||||
# Loop over available_devices
|
||||
for i in range(len(Config.available_devices)):
|
||||
# Execute operations of tower i on device i
|
||||
device = Config.available_devices[i]
|
||||
with tf.device(device):
|
||||
# Create a scope for all operations of tower i
|
||||
with tf.name_scope('tower_%d' % i):
|
||||
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
|
||||
# batch along with the original batch's labels (Y) of this tower
|
||||
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
|
||||
# Allow for variables to be re-used by the next tower
|
||||
tfv1.get_variable_scope().reuse_variables()
|
||||
|
||||
# Retain tower's avg losses
|
||||
tower_avg_losses.append(avg_loss)
|
||||
|
||||
# Compute gradients for model parameters using tower's mini-batch
|
||||
gradients = optimizer.compute_gradients(avg_loss)
|
||||
|
||||
# Retain tower's gradients
|
||||
tower_gradients.append(gradients)
|
||||
|
||||
tower_non_finite_files.append(non_finite_files)
|
||||
|
||||
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
|
||||
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
|
||||
|
||||
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
|
||||
|
||||
# Return gradients and the average loss
|
||||
return tower_gradients, avg_loss_across_towers, all_non_finite_files
|
||||
|
||||
|
||||
def average_gradients(tower_gradients):
|
||||
r'''
|
||||
A routine for computing each variable's average of the gradients obtained from the GPUs.
|
||||
Note also that this code acts as a synchronization point as it requires all
|
||||
GPUs to be finished with their mini-batch before it can run to completion.
|
||||
'''
|
||||
# List of average gradients to return to the caller
|
||||
average_grads = []
|
||||
|
||||
# Run this on cpu_device to conserve GPU memory
|
||||
with tf.device(Config.cpu_device):
|
||||
# Loop over gradient/variable pairs from all towers
|
||||
for grad_and_vars in zip(*tower_gradients):
|
||||
# Introduce grads to store the gradients for the current variable
|
||||
grads = []
|
||||
|
||||
# Loop over the gradients for the current variable
|
||||
for g, _ in grad_and_vars:
|
||||
# Add 0 dimension to the gradients to represent the tower.
|
||||
expanded_g = tf.expand_dims(g, 0)
|
||||
# Append on a 'tower' dimension which we will average over below.
|
||||
grads.append(expanded_g)
|
||||
|
||||
# Average over the 'tower' dimension
|
||||
grad = tf.concat(grads, 0)
|
||||
grad = tf.reduce_mean(input_tensor=grad, axis=0)
|
||||
|
||||
# Create a gradient/variable tuple for the current variable with its average gradient
|
||||
grad_and_var = (grad, grad_and_vars[0][1])
|
||||
|
||||
# Add the current tuple to average_grads
|
||||
average_grads.append(grad_and_var)
|
||||
|
||||
# Return result to caller
|
||||
return average_grads
|
||||
|
||||
|
||||
|
||||
# Logging
|
||||
# =======
|
||||
|
||||
def log_variable(variable, gradient=None):
|
||||
r'''
|
||||
We introduce a function for logging a tensor variable's current state.
|
||||
It logs scalar values for the mean, standard deviation, minimum and maximum.
|
||||
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
|
||||
'''
|
||||
name = variable.name.replace(':', '_')
|
||||
mean = tf.reduce_mean(input_tensor=variable)
|
||||
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
|
||||
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
|
||||
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
|
||||
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
|
||||
tfv1.summary.histogram(name=name, values=variable)
|
||||
if gradient is not None:
|
||||
if isinstance(gradient, tf.IndexedSlices):
|
||||
grad_values = gradient.values
|
||||
else:
|
||||
grad_values = gradient
|
||||
if grad_values is not None:
|
||||
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
|
||||
|
||||
|
||||
def log_grads_and_vars(grads_and_vars):
|
||||
r'''
|
||||
Let's also introduce a helper function for logging collections of gradient/variable tuples.
|
||||
'''
|
||||
for gradient, variable in grads_and_vars:
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def train():
|
||||
do_cache_dataset = True
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||
FLAGS.data_aug_features_additive > 0 or
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0 or
|
||||
FLAGS.augmentation_sparse_warp):
|
||||
do_cache_dataset = False
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
enable_cache=FLAGS.feature_cache and do_cache_dataset,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
output_classes=tfv1.data.get_output_classes(train_set))
|
||||
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
dev_sources = FLAGS.dev_files.split(',')
|
||||
dev_sets = [create_dataset([source],
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
# Dropout
|
||||
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
dropout_feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
no_dropout_feed_dict = {
|
||||
rate: 0. for rate in dropout_rates
|
||||
}
|
||||
|
||||
# Building the graph
|
||||
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
|
||||
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
|
||||
optimizer = create_optimizer(learning_rate_var)
|
||||
|
||||
# Enable mixed precision training
|
||||
if FLAGS.automatic_mixed_precision:
|
||||
log_info('Enabling automatic mixed precision training.')
|
||||
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||
|
||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tfv1.train.get_or_create_global_step()
|
||||
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
|
||||
|
||||
# Summaries
|
||||
step_summaries_op = tfv1.summary.merge_all('step_summaries')
|
||||
step_summary_writers = {
|
||||
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
|
||||
}
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
|
||||
|
||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
# Load checkpoint or initialize variables
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last', 'init']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_set(set_name, epoch, init_op, dataset=None):
|
||||
is_train = set_name == 'train'
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||
|
||||
total_loss = 0.0
|
||||
step_count = 0
|
||||
|
||||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
def __init__(self):
|
||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||
|
||||
def __call__(self, progress, data, **kwargs):
|
||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||
|
||||
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
|
||||
widgets = [' | ', progressbar.widgets.Timer(),
|
||||
' | Steps: ', progressbar.widgets.Counter(),
|
||||
' | ', LossWidget()]
|
||||
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# Batch loop
|
||||
while True:
|
||||
try:
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.InvalidArgumentError as err:
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
log_info("Ignoring sparse warp error: {}".format(err))
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
|
||||
log_error('The following files caused an infinite (or NaN) '
|
||||
'loss: {}'.format(','.join(problem_files)))
|
||||
|
||||
total_loss += batch_loss
|
||||
step_count += 1
|
||||
|
||||
pbar.update(step_count)
|
||||
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
|
||||
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
pbar.finish()
|
||||
mean_loss = total_loss / step_count if step_count > 0 else 0.0
|
||||
return mean_loss, step_count
|
||||
|
||||
log_info('STARTING Optimization')
|
||||
train_start_time = datetime.utcnow()
|
||||
best_dev_loss = float('inf')
|
||||
dev_losses = []
|
||||
epochs_without_improvement = 0
|
||||
try:
|
||||
for epoch in range(FLAGS.epochs):
|
||||
# Training
|
||||
log_progress('Training epoch %d...' % epoch)
|
||||
train_loss, _ = run_set('train', epoch, train_init_op)
|
||||
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, source))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
total_steps += steps
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
dev_loss = dev_loss / total_steps
|
||||
dev_losses.append(dev_loss)
|
||||
|
||||
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
|
||||
# the improvement has to be greater than FLAGS.es_min_delta
|
||||
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
|
||||
epochs_without_improvement += 1
|
||||
else:
|
||||
epochs_without_improvement = 0
|
||||
|
||||
# Save new best model
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
|
||||
# Early stopping
|
||||
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
|
||||
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
|
||||
epochs_without_improvement))
|
||||
break
|
||||
|
||||
# Reduce learning rate on plateau
|
||||
if (FLAGS.reduce_lr_on_plateau and
|
||||
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
session.run(reduce_learning_rate_op)
|
||||
current_learning_rate = learning_rate_var.eval()
|
||||
log_info('Encountered a plateau, reducing learning rate to {}'.format(
|
||||
current_learning_rate))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||
log_debug('Session closed.')
|
||||
|
||||
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
|
||||
mfccs = tf.identity(mfccs, name='mfccs')
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in DS_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
|
||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
|
||||
if tflite:
|
||||
rnn_impl = rnn_impl_static_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
logits, layers = create_model(batch_x=input_tensor,
|
||||
batch_size=batch_size,
|
||||
seq_length=seq_length if not FLAGS.export_tflite else None,
|
||||
dropout=no_dropout,
|
||||
previous_state=previous_state,
|
||||
overlap=False,
|
||||
rnn_impl=rnn_impl)
|
||||
|
||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits, name='logits')
|
||||
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
||||
if n_steps > 0:
|
||||
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': logits,
|
||||
},
|
||||
layers
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||
|
||||
inputs = {
|
||||
'input': input_tensor,
|
||||
'previous_state_c': previous_state_c,
|
||||
'previous_state_h': previous_state_h,
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
inputs['input_lengths'] = seq_length
|
||||
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
|
||||
def export():
|
||||
r'''
|
||||
Restores the trained variables into a simpler graph that will be exported for serving.
|
||||
'''
|
||||
log_info('Exporting the model...')
|
||||
from tensorflow.python.framework.ops import Tensor, Operation
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
|
||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
assert graph_version > 0
|
||||
|
||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||
|
||||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
output_names = output_names_tensors + output_names_ops
|
||||
|
||||
with tf.Session() as session:
|
||||
# Restore variables from checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
sess=session,
|
||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||
output_node_names=output_names)
|
||||
|
||||
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||
graph_def=frozen_graph,
|
||||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
|
||||
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
|
||||
FLAGS.export_author_id,
|
||||
FLAGS.export_model_name,
|
||||
FLAGS.export_model_version))
|
||||
|
||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||
with open(metadata_fname, 'w') as f:
|
||||
f.write('---\n')
|
||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
|
||||
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
|
||||
f.write('license: {}\n'.format(FLAGS.export_license))
|
||||
f.write('language: {}\n'.format(FLAGS.export_language))
|
||||
f.write('runtime: {}\n'.format(model_runtime))
|
||||
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
|
||||
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
|
||||
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
|
||||
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
|
||||
f.write('---\n')
|
||||
f.write('{}\n'.format(FLAGS.export_description))
|
||||
|
||||
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
||||
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
log_info('Exported packaged model {}'.format(archive))
|
||||
|
||||
|
||||
def do_single_file_inference(input_file_path):
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
||||
|
||||
# Add batch dimension
|
||||
features = tf.expand_dims(features, 0)
|
||||
features_len = tf.expand_dims(features_len, 0)
|
||||
|
||||
# Evaluate
|
||||
features = create_overlapping_windows(features).eval(session=session)
|
||||
features_len = features_len.eval(session=session)
|
||||
|
||||
logits = outputs['outputs'].eval(feed_dict={
|
||||
inputs['input']: features,
|
||||
inputs['input_lengths']: features_len,
|
||||
inputs['previous_state_c']: previous_state_c,
|
||||
inputs['previous_state_h']: previous_state_h,
|
||||
}, session=session)
|
||||
|
||||
logits = np.squeeze(logits)
|
||||
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
||||
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
# Print highest probability result
|
||||
print(decoded[0][1])
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if FLAGS.train_files:
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test_files:
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
export()
|
||||
package_zip()
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
try:
|
||||
from deepspeech_training import train as ds_train
|
||||
except ImportError:
|
||||
print('Training package is not installed. See training documentation.')
|
||||
raise
|
||||
|
||||
ds_train.run_script()
|
||||
|
@ -150,7 +150,7 @@ COPY . /DeepSpeech/
|
||||
|
||||
WORKDIR /DeepSpeech
|
||||
|
||||
RUN pip3 --no-cache-dir install -r requirements.txt
|
||||
RUN pip3 --no-cache-dir install .
|
||||
|
||||
# Link DeepSpeech native_client libs to tf folder
|
||||
RUN ln -s /DeepSpeech/native_client /tensorflow
|
||||
|
@ -5,18 +5,12 @@ Use "python3 build_sdb.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import argparse
|
||||
import progressbar
|
||||
|
||||
from util.downloader import SIMPLE_BAR
|
||||
from util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
|
||||
from util.sample_collections import samples_from_files, DirectSDBWriter
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.audio import change_audio_types, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS
|
||||
from deepspeech_training.util.sample_collections import samples_from_files, DirectSDBWriter
|
||||
|
||||
AUDIO_TYPE_LOOKUP = {
|
||||
'wav': AUDIO_TYPE_WAV,
|
||||
|
@ -1,11 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
|
||||
import os
|
||||
sys.path.append(os.path.abspath('.'))
|
||||
|
||||
from util.gpu_usage import GPUUsage
|
||||
from deepspeech_training.util.gpu_usage import GPUUsage
|
||||
|
||||
gu = GPUUsage()
|
||||
gu.start()
|
||||
|
@ -1,10 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
|
||||
import os
|
||||
sys.path.append(os.path.abspath('.'))
|
||||
|
||||
from util.gpu_usage import GPUUsageChart
|
||||
from deepspeech_training.util.gpu_usage import GPUUsageChart
|
||||
|
||||
GPUUsageChart(sys.argv[1], sys.argv[2])
|
||||
|
@ -1,17 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import os
|
||||
import pandas
|
||||
import tarfile
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
|
@ -1,17 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import tarfile
|
||||
import os
|
||||
import pandas
|
||||
import tarfile
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMNNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
|
@ -1,23 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import sox
|
||||
import tarfile
|
||||
import subprocess
|
||||
import os
|
||||
import progressbar
|
||||
import sox
|
||||
import subprocess
|
||||
import tarfile
|
||||
|
||||
from glob import glob
|
||||
from os import path
|
||||
from multiprocessing import Pool
|
||||
from util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
SAMPLE_RATE = 16000
|
||||
@ -28,7 +22,7 @@ ARCHIVE_URL = 'https://s3.us-east-2.amazonaws.com/common-voice-data-download/' +
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract common voice data
|
||||
@ -38,8 +32,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
with tarfile.open(archive_path) as tar:
|
||||
tar.extractall(target_dir)
|
||||
@ -47,9 +41,9 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
print('Found directory "%s" - not extracting it from archive.' % extracted_path)
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
for source_csv in glob(path.join(extracted_dir, '*.csv')):
|
||||
_maybe_convert_set(extracted_dir, source_csv, path.join(target_dir, os.path.split(source_csv)[-1]))
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
for source_csv in glob(os.path.join(extracted_dir, '*.csv')):
|
||||
_maybe_convert_set(extracted_dir, source_csv, os.path.join(target_dir, os.path.split(source_csv)[-1]))
|
||||
|
||||
def one_sample(sample):
|
||||
mp3_filename = sample[0]
|
||||
@ -58,7 +52,7 @@ def one_sample(sample):
|
||||
_maybe_convert_wav(mp3_filename, wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
file_size = -1
|
||||
if path.exists(wav_filename):
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = validate_label(sample[1])
|
||||
@ -85,7 +79,7 @@ def one_sample(sample):
|
||||
|
||||
def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||
print()
|
||||
if path.exists(target_csv):
|
||||
if os.path.exists(target_csv):
|
||||
print('Found CSV file "%s" - not importing "%s".' % (target_csv, source_csv))
|
||||
return
|
||||
print('No CSV file "%s" - importing "%s"...' % (target_csv, source_csv))
|
||||
@ -126,7 +120,7 @@ def _maybe_convert_set(extracted_dir, source_csv, target_csv):
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
|
@ -8,23 +8,17 @@ Use "python3 import_cv2.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import os
|
||||
import progressbar
|
||||
import sox
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
|
||||
from os import path
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
from util.text import Alphabet
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
@ -34,7 +28,7 @@ MAX_SECS = 10
|
||||
|
||||
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||
for dataset in ['train', 'test', 'dev', 'validated', 'other']:
|
||||
input_tsv = path.join(path.abspath(tsv_dir), dataset+".tsv")
|
||||
input_tsv = os.path.join(os.path.abspath(tsv_dir), dataset+".tsv")
|
||||
if os.path.isfile(input_tsv):
|
||||
print("Loading TSV file: ", input_tsv)
|
||||
_maybe_convert_set(input_tsv, audio_dir, space_after_every_character)
|
||||
@ -42,15 +36,15 @@ def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
|
||||
def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
mp3_filename = sample[0]
|
||||
if not path.splitext(mp3_filename.lower())[1] == '.mp3':
|
||||
if not os.path.splitext(mp3_filename.lower())[1] == '.mp3':
|
||||
mp3_filename += ".mp3"
|
||||
# Storing wav files next to the mp3 ones - just with a different suffix
|
||||
wav_filename = path.splitext(mp3_filename)[0] + ".wav"
|
||||
wav_filename = os.path.splitext(mp3_filename)[0] + ".wav"
|
||||
_maybe_convert_wav(mp3_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = label_filter_fun(sample[1])
|
||||
rows = []
|
||||
@ -76,7 +70,7 @@ def one_sample(sample):
|
||||
return (counter, rows)
|
||||
|
||||
def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||
output_csv = path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
|
||||
output_csv = os.path.join(audio_dir, os.path.split(input_tsv)[-1].replace('tsv', 'csv'))
|
||||
print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
|
||||
|
||||
# Get audiofile path and transcript for each sentence in tsv
|
||||
@ -84,7 +78,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||
with open(input_tsv, encoding='utf-8') as input_tsv_file:
|
||||
reader = csv.DictReader(input_tsv_file, delimiter='\t')
|
||||
for row in reader:
|
||||
samples.append((path.join(audio_dir, row['path']), row['sentence']))
|
||||
samples.append((os.path.join(audio_dir, row['path']), row['sentence']))
|
||||
|
||||
counter = get_counter()
|
||||
num_samples = len(samples)
|
||||
@ -120,7 +114,7 @@ def _maybe_convert_set(input_tsv, audio_dir, space_after_every_character=None):
|
||||
|
||||
|
||||
def _maybe_convert_wav(mp3_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
|
@ -4,22 +4,17 @@ from __future__ import absolute_import, division, print_function
|
||||
# Prerequisite: Having the sph2pipe tool in your PATH:
|
||||
# https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import librosa
|
||||
import os
|
||||
import pandas
|
||||
import subprocess
|
||||
import unicodedata
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
import subprocess
|
||||
import sys
|
||||
import unicodedata
|
||||
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
||||
|
@ -1,18 +1,13 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import numpy as np
|
||||
import os
|
||||
import pandas
|
||||
import tarfile
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
|
@ -1,22 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import math
|
||||
import urllib
|
||||
import logging
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
import subprocess
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
|
||||
import swifter
|
||||
import math
|
||||
import os
|
||||
import pandas as pd
|
||||
import swifter
|
||||
import subprocess
|
||||
import urllib
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label
|
||||
from pathlib import Path
|
||||
from sox import Transformer
|
||||
|
||||
|
||||
@ -142,11 +136,11 @@ class GramVaaniDownloader:
|
||||
return mp3_directory
|
||||
|
||||
def _pre_download(self):
|
||||
mp3_directory = path.join(self.target_dir, "mp3")
|
||||
if not path.exists(self.target_dir):
|
||||
mp3_directory = os.path.join(self.target_dir, "mp3")
|
||||
if not os.path.exists(self.target_dir):
|
||||
_logger.info("Creating directory...%s", self.target_dir)
|
||||
os.mkdir(self.target_dir)
|
||||
if not path.exists(mp3_directory):
|
||||
if not os.path.exists(mp3_directory):
|
||||
_logger.info("Creating directory...%s", mp3_directory)
|
||||
os.mkdir(mp3_directory)
|
||||
return mp3_directory
|
||||
@ -154,8 +148,8 @@ class GramVaaniDownloader:
|
||||
def _download(self, audio_url, transcript, audio_length, mp3_directory):
|
||||
if audio_url == "audio_url":
|
||||
return
|
||||
mp3_filename = path.join(mp3_directory, os.path.basename(audio_url))
|
||||
if not path.exists(mp3_filename):
|
||||
mp3_filename = os.path.join(mp3_directory, os.path.basename(audio_url))
|
||||
if not os.path.exists(mp3_filename):
|
||||
_logger.debug("Downloading mp3 file...%s", audio_url)
|
||||
urllib.request.urlretrieve(audio_url, mp3_filename)
|
||||
else:
|
||||
@ -182,8 +176,8 @@ class GramVaaniConverter:
|
||||
"""
|
||||
wav_directory = self._pre_convert()
|
||||
for mp3_filename in self.mp3_directory.glob('**/*.mp3'):
|
||||
wav_filename = path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
if not path.exists(wav_filename):
|
||||
wav_filename = os.path.join(wav_directory, os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
if not os.path.exists(wav_filename):
|
||||
_logger.debug("Converting mp3 file %s to wav file %s" % (mp3_filename, wav_filename))
|
||||
transformer = Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH)
|
||||
@ -193,11 +187,11 @@ class GramVaaniConverter:
|
||||
return wav_directory
|
||||
|
||||
def _pre_convert(self):
|
||||
wav_directory = path.join(self.target_dir, "wav")
|
||||
if not path.exists(self.target_dir):
|
||||
wav_directory = os.path.join(self.target_dir, "wav")
|
||||
if not os.path.exists(self.target_dir):
|
||||
_logger.info("Creating directory...%s", self.target_dir)
|
||||
os.mkdir(self.target_dir)
|
||||
if not path.exists(wav_directory):
|
||||
if not os.path.exists(wav_directory):
|
||||
_logger.info("Creating directory...%s", wav_directory)
|
||||
os.mkdir(wav_directory)
|
||||
return wav_directory
|
||||
@ -233,8 +227,8 @@ class GramVaaniDataSets:
|
||||
if audio_url == "audio_url":
|
||||
return pd.Series(["wav_filename", "wav_filesize", "transcript"])
|
||||
mp3_filename = os.path.basename(audio_url)
|
||||
wav_relative_filename = path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
wav_filesize = path.getsize(path.join(self.target_dir, wav_relative_filename))
|
||||
wav_relative_filename = os.path.join("wav", os.path.splitext(os.path.basename(mp3_filename))[0] + ".wav")
|
||||
wav_filesize = os.path.getsize(os.path.join(self.target_dir, wav_relative_filename))
|
||||
transcript = validate_label(transcript)
|
||||
if None == transcript:
|
||||
transcript = ""
|
||||
@ -252,7 +246,7 @@ class GramVaaniDataSets:
|
||||
|
||||
def _is_valid_raw_wav_frames(self):
|
||||
transcripts = [str(transcript) for transcript in self.raw.transcript]
|
||||
wav_filepaths = [path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
||||
wav_filepaths = [os.path.join(self.target_dir, str(wav_filename)) for wav_filename in self.raw.wav_filename]
|
||||
wav_frames = [int(subprocess.check_output(['soxi', '-s', wav_filepath], stderr=subprocess.STDOUT)) for wav_filepath in wav_filepaths]
|
||||
is_valid_raw_wav_frames = [self._is_wav_frame_valid(wav_frame, transcript) for wav_frame, transcript in zip(wav_frames, transcripts)]
|
||||
return pd.Series(is_valid_raw_wav_frames)
|
||||
|
@ -1,15 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import pandas
|
||||
import os
|
||||
import sys
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data
|
||||
|
@ -1,22 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import os
|
||||
import pandas
|
||||
import progressbar
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from sox import Transformer
|
||||
from util.downloader import maybe_download
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
@ -1,31 +1,23 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import progressbar
|
||||
import re
|
||||
import sox
|
||||
import zipfile
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
import zipfile
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
SAMPLE_RATE = 16000
|
||||
@ -38,7 +30,7 @@ ARCHIVE_URL = 'https://lingualibre.fr/datasets/' + ARCHIVE_NAME
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -48,8 +40,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -62,12 +54,12 @@ def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
ogg_filename = sample[0]
|
||||
# Storing wav files next to the ogg ones - just with a different suffix
|
||||
wav_filename = path.splitext(ogg_filename)[0] + ".wav"
|
||||
wav_filename = os.path.splitext(ogg_filename)[0] + ".wav"
|
||||
_maybe_convert_wav(ogg_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = label_filter(sample[1])
|
||||
rows = []
|
||||
@ -94,7 +86,7 @@ def one_sample(sample):
|
||||
return (counter, rows)
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME + '_' + ARCHIVE_NAME.replace('.zip', '_{}.csv'))
|
||||
if os.path.isfile(target_csv_template):
|
||||
@ -160,7 +152,7 @@ def _maybe_convert_sets(target_dir, extracted_data):
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
def _maybe_convert_wav(ogg_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
|
@ -2,29 +2,19 @@
|
||||
# pylint: disable=invalid-name
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import subprocess
|
||||
import os
|
||||
import progressbar
|
||||
import unicodedata
|
||||
import subprocess
|
||||
import tarfile
|
||||
import unicodedata
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from glob import glob
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
from multiprocessing import Pool
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
SAMPLE_RATE = 16000
|
||||
@ -37,7 +27,7 @@ ARCHIVE_URL = 'http://www.caito.de/data/Training/stt_tts/' + ARCHIVE_NAME
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -48,8 +38,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -65,8 +55,8 @@ def one_sample(sample):
|
||||
wav_filename = sample[0]
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = label_filter(sample[1])
|
||||
counter = get_counter()
|
||||
@ -93,7 +83,7 @@ def one_sample(sample):
|
||||
return (counter, rows)
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tgz', '_{}.csv'))
|
||||
if os.path.isfile(target_csv_template):
|
||||
|
@ -1,18 +1,13 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import os
|
||||
import pandas
|
||||
import tarfile
|
||||
import wave
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
|
@ -1,19 +1,14 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import pandas
|
||||
import tarfile
|
||||
|
||||
from deepspeech_training.util.importers import get_importers_parser
|
||||
|
||||
COLUMN_NAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
|
@ -1,32 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import os
|
||||
import progressbar
|
||||
import re
|
||||
import sox
|
||||
import zipfile
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unicodedata
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import zipfile
|
||||
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from glob import glob
|
||||
from multiprocessing import Pool
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet
|
||||
from util.helpers import secs_to_hours
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
SAMPLE_RATE = 16000
|
||||
@ -39,7 +29,7 @@ ARCHIVE_URL = 'http://www.openslr.org/resources/57/' + ARCHIVE_NAME
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract data
|
||||
@ -49,8 +39,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -65,8 +55,8 @@ def one_sample(sample):
|
||||
wav_filename = sample[0]
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = label_filter(sample[1])
|
||||
counter = get_counter()
|
||||
@ -92,7 +82,7 @@ def one_sample(sample):
|
||||
return (counter, rows)
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, ARCHIVE_DIR_NAME, ARCHIVE_NAME.replace('.tar.gz', '_{}.csv'))
|
||||
if os.path.isfile(target_csv_template):
|
||||
|
@ -1,28 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
|
||||
# ensure that you have downloaded the LDC dataset LDC97S62 and tar exists in a folder e.g.
|
||||
# ./data/swb/swb1_LDC97S62.tgz
|
||||
# from the deepspeech directory run with: ./bin/import_swb.py ./data/swb/
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import fnmatch
|
||||
import librosa
|
||||
import os
|
||||
import pandas
|
||||
import requests
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import codecs
|
||||
import tarfile
|
||||
import requests
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
|
||||
|
||||
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03
|
||||
ARCHIVE_NAME = 'switchboard_word_alignments.tar.gz'
|
||||
|
@ -5,31 +5,26 @@ Use "python3 import_swc.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import re
|
||||
import csv
|
||||
import sox
|
||||
import wave
|
||||
import shutil
|
||||
import random
|
||||
import tarfile
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import progressbar
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import sox
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from glob import glob
|
||||
from collections import Counter
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||
SWC_ARCHIVE = "SWC_{language}.tar"
|
||||
@ -117,8 +112,8 @@ def maybe_download_language(language):
|
||||
|
||||
|
||||
def maybe_extract(data_dir, extracted_data, archive):
|
||||
extracted = path.join(data_dir, extracted_data)
|
||||
if path.isdir(extracted):
|
||||
extracted = os.path.join(data_dir, extracted_data)
|
||||
if os.path.isdir(extracted):
|
||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||
else:
|
||||
print('Extracting "{}"...'.format(archive))
|
||||
@ -242,7 +237,7 @@ def collect_samples(base_dir, language):
|
||||
print('Collecting samples...')
|
||||
bar = progressbar.ProgressBar(max_value=len(roots), widgets=SIMPLE_BAR)
|
||||
for root in bar(roots):
|
||||
wav_path = path.join(root, WAV_NAME)
|
||||
wav_path = os.path.join(root, WAV_NAME)
|
||||
aligned = ET.parse(path.join(root, ALIGNED_NAME))
|
||||
article = UNKNOWN
|
||||
speaker = UNKNOWN
|
||||
@ -294,8 +289,8 @@ def maybe_convert_one_to_wav(entry):
|
||||
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||
combiner = sox.Combiner()
|
||||
combiner.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
|
||||
output_wav = path.join(root, WAV_NAME)
|
||||
if path.isfile(output_wav):
|
||||
output_wav = os.path.join(root, WAV_NAME)
|
||||
if os.path.isfile(output_wav):
|
||||
return
|
||||
files = sorted(glob(path.join(root, AUDIO_PATTERN)))
|
||||
try:
|
||||
@ -304,7 +299,7 @@ def maybe_convert_one_to_wav(entry):
|
||||
elif len(files) > 1:
|
||||
wav_files = []
|
||||
for i, file in enumerate(files):
|
||||
wav_path = path.join(root, 'audio{}.wav'.format(i))
|
||||
wav_path = os.path.join(root, 'audio{}.wav'.format(i))
|
||||
transformer.build(file, wav_path)
|
||||
wav_files.append(wav_path)
|
||||
combiner.set_input_format(file_type=['wav'] * len(wav_files))
|
||||
@ -358,8 +353,8 @@ def assign_sub_sets(samples):
|
||||
def create_sample_dirs(language):
|
||||
print('Creating sample directories...')
|
||||
for set_name in ['train', 'dev', 'test']:
|
||||
dir_path = path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
||||
if not path.isdir(dir_path):
|
||||
dir_path = os.path.join(CLI_ARGS.base_dir, language + '-' + set_name)
|
||||
if not os.path.isdir(dir_path):
|
||||
os.mkdir(dir_path)
|
||||
|
||||
|
||||
@ -374,7 +369,7 @@ def split_audio_files(samples, language):
|
||||
rate = src_wav_file.getframerate()
|
||||
for sample in file_samples:
|
||||
index = sub_sets[sample.sub_set]
|
||||
sample_wav_path = path.join(CLI_ARGS.base_dir,
|
||||
sample_wav_path = os.path.join(CLI_ARGS.base_dir,
|
||||
language + '-' + sample.sub_set,
|
||||
'sample-{0:06d}.wav'.format(index))
|
||||
sample.wav_path = sample_wav_path
|
||||
@ -391,8 +386,8 @@ def split_audio_files(samples, language):
|
||||
def write_csvs(samples, language):
|
||||
for sub_set, set_samples in group(samples, lambda s: s.sub_set).items():
|
||||
set_samples = sorted(set_samples, key=lambda s: s.wav_path)
|
||||
base_dir = path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = path.join(base_dir, language + '-' + sub_set + '.csv')
|
||||
base_dir = os.path.abspath(CLI_ARGS.base_dir)
|
||||
csv_path = os.path.join(base_dir, language + '-' + sub_set + '.csv')
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES_EXT if CLI_ARGS.add_meta else FIELDNAMES)
|
||||
@ -400,8 +395,8 @@ def write_csvs(samples, language):
|
||||
bar = progressbar.ProgressBar(max_value=len(set_samples), widgets=SIMPLE_BAR)
|
||||
for sample in bar(set_samples):
|
||||
row = {
|
||||
'wav_filename': path.relpath(sample.wav_path, base_dir),
|
||||
'wav_filesize': path.getsize(sample.wav_path),
|
||||
'wav_filename': os.path.relpath(sample.wav_path, base_dir),
|
||||
'wav_filesize': os.path.getsize(sample.wav_path),
|
||||
'transcript': sample.text
|
||||
}
|
||||
if CLI_ARGS.add_meta:
|
||||
@ -414,8 +409,8 @@ def cleanup(archive, language):
|
||||
if not CLI_ARGS.keep_archive:
|
||||
print('Removing archive "{}"...'.format(archive))
|
||||
os.remove(archive)
|
||||
language_dir = path.join(CLI_ARGS.base_dir, language)
|
||||
if not CLI_ARGS.keep_intermediate and path.isdir(language_dir):
|
||||
language_dir = os.path.join(CLI_ARGS.base_dir, language)
|
||||
if not CLI_ARGS.keep_intermediate and os.path.isdir(language_dir):
|
||||
print('Removing intermediate files in "{}"...'.format(language_dir))
|
||||
shutil.rmtree(language_dir)
|
||||
|
||||
|
@ -1,14 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import codecs
|
||||
import pandas
|
||||
import sys
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
@ -16,9 +10,10 @@ import wave
|
||||
from glob import glob
|
||||
from os import makedirs, path, remove, rmdir
|
||||
from sox import Transformer
|
||||
from util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from tensorflow.python.platform import gfile
|
||||
from util.stm import parse_stm_file
|
||||
from deepspeech_training.util.stm import parse_stm_file
|
||||
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Conditionally download data
|
||||
|
@ -1,28 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import csv
|
||||
import unidecode
|
||||
import zipfile
|
||||
import os
|
||||
import progressbar
|
||||
import re
|
||||
import sox
|
||||
import subprocess
|
||||
import progressbar
|
||||
import unidecode
|
||||
import zipfile
|
||||
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from deepspeech_training.util.downloader import SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import get_importers_parser, get_validate_label, get_counter, get_imported_samples, print_import_report
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import SIMPLE_BAR
|
||||
|
||||
from os import path
|
||||
|
||||
from util.downloader import maybe_download
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
SAMPLE_RATE = 16000
|
||||
@ -34,7 +26,7 @@ ARCHIVE_URL = 'https://deepspeech-storage-mirror.s3.fr-par.scw.cloud/' + ARCHIVE
|
||||
|
||||
def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download('ts_' + ARCHIVE_NAME + '.zip', target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract archive data
|
||||
@ -45,8 +37,8 @@ def _download_and_preprocess_data(target_dir, english_compatible=False):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print('No directory "%s" - extracting archive...' % extracted_path)
|
||||
if not os.path.isdir(extracted_path):
|
||||
os.mkdir(extracted_path)
|
||||
@ -60,12 +52,12 @@ def one_sample(sample):
|
||||
""" Take a audio file, and optionally convert it to 16kHz WAV """
|
||||
orig_filename = sample['path']
|
||||
# Storing wav files next to the wav ones - just with a different suffix
|
||||
wav_filename = path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
wav_filename = os.path.splitext(orig_filename)[0] + ".converted.wav"
|
||||
_maybe_convert_wav(orig_filename, wav_filename)
|
||||
file_size = -1
|
||||
frames = 0
|
||||
if path.exists(wav_filename):
|
||||
file_size = path.getsize(wav_filename)
|
||||
if os.path.exists(wav_filename):
|
||||
file_size = os.path.getsize(wav_filename)
|
||||
frames = int(subprocess.check_output(['soxi', '-s', wav_filename], stderr=subprocess.STDOUT))
|
||||
label = sample['text']
|
||||
|
||||
@ -95,7 +87,7 @@ def one_sample(sample):
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
extracted_dir = path.join(target_dir, extracted_data)
|
||||
extracted_dir = os.path.join(target_dir, extracted_data)
|
||||
# override existing CSV with normalized one
|
||||
target_csv_template = os.path.join(target_dir, 'ts_' + ARCHIVE_NAME + '_{}.csv')
|
||||
if os.path.isfile(target_csv_template):
|
||||
@ -160,7 +152,7 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
|
||||
print_import_report(counter, SAMPLE_RATE, MAX_SECS)
|
||||
|
||||
def _maybe_convert_wav(orig_filename, wav_filename):
|
||||
if not path.exists(wav_filename):
|
||||
if not os.path.exists(wav_filename):
|
||||
transformer = sox.Transformer()
|
||||
transformer.convert(samplerate=SAMPLE_RATE)
|
||||
try:
|
||||
|
@ -5,25 +5,19 @@ Use "python3 import_tuda.py -h" for help
|
||||
'''
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import csv
|
||||
import wave
|
||||
import tarfile
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import progressbar
|
||||
import tarfile
|
||||
import unicodedata
|
||||
import wave
|
||||
import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from collections import Counter
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import validate_label_eng as validate_label
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
TUDA_VERSION = 'v2'
|
||||
TUDA_PACKAGE = 'german-speechdata-package-{}'.format(TUDA_VERSION)
|
||||
@ -38,8 +32,8 @@ FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
|
||||
def maybe_extract(archive):
|
||||
extracted = path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
|
||||
if path.isdir(extracted):
|
||||
extracted = os.path.join(CLI_ARGS.base_dir, TUDA_PACKAGE)
|
||||
if os.path.isdir(extracted):
|
||||
print('Found directory "{}" - not extracting.'.format(extracted))
|
||||
else:
|
||||
print('Extracting "{}"...'.format(archive))
|
||||
@ -92,7 +86,7 @@ def write_csvs(extracted):
|
||||
sample_counter = 0
|
||||
reasons = Counter()
|
||||
for sub_set in ['train', 'dev', 'test']:
|
||||
set_path = path.join(extracted, sub_set)
|
||||
set_path = os.path.join(extracted, sub_set)
|
||||
set_files = os.listdir(set_path)
|
||||
recordings = {}
|
||||
for file in set_files:
|
||||
@ -104,15 +98,15 @@ def write_csvs(extracted):
|
||||
if prefix in recordings:
|
||||
recordings[prefix].append(file)
|
||||
recordings = recordings.items()
|
||||
csv_path = path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
|
||||
csv_path = os.path.join(CLI_ARGS.base_dir, 'tuda-{}-{}.csv'.format(TUDA_VERSION, sub_set))
|
||||
print('Writing "{}"...'.format(csv_path))
|
||||
with open(csv_path, 'w') as csv_file:
|
||||
writer = csv.DictWriter(csv_file, fieldnames=FIELDNAMES)
|
||||
writer.writeheader()
|
||||
set_dir = path.join(extracted, sub_set)
|
||||
set_dir = os.path.join(extracted, sub_set)
|
||||
bar = progressbar.ProgressBar(max_value=len(recordings), widgets=SIMPLE_BAR)
|
||||
for prefix, wav_names in bar(recordings):
|
||||
xml_path = path.join(set_dir, prefix + '.xml')
|
||||
xml_path = os.path.join(set_dir, prefix + '.xml')
|
||||
meta = ET.parse(xml_path).getroot()
|
||||
sentence = list(meta.iter('cleaned_sentence'))[0].text
|
||||
sentence = check_and_prepare_sentence(sentence)
|
||||
@ -120,12 +114,12 @@ def write_csvs(extracted):
|
||||
continue
|
||||
for wav_name in wav_names:
|
||||
sample_counter += 1
|
||||
wav_path = path.join(set_path, wav_name)
|
||||
wav_path = os.path.join(set_path, wav_name)
|
||||
keep, reason = check_wav_file(wav_path, sentence)
|
||||
if keep:
|
||||
writer.writerow({
|
||||
'wav_filename': path.relpath(wav_path, CLI_ARGS.base_dir),
|
||||
'wav_filesize': path.getsize(wav_path),
|
||||
'wav_filename': os.path.relpath(wav_path, CLI_ARGS.base_dir),
|
||||
'wav_filesize': os.path.getsize(wav_path),
|
||||
'transcript': sentence.lower()
|
||||
})
|
||||
else:
|
||||
|
@ -1,30 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# VCTK used in wavenet paper https://arxiv.org/pdf/1609.03499.pdf
|
||||
# Licenced under Open Data Commons Attribution License (ODC-By) v1.0.
|
||||
# as per https://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], ".."))
|
||||
|
||||
from util.importers import get_counter, get_imported_samples, print_import_report
|
||||
|
||||
import re
|
||||
import librosa
|
||||
import os
|
||||
import progressbar
|
||||
import random
|
||||
import re
|
||||
|
||||
from os import path
|
||||
from deepspeech_training.util.downloader import maybe_download, SIMPLE_BAR
|
||||
from deepspeech_training.util.importers import get_counter, get_imported_samples, print_import_report
|
||||
from multiprocessing import Pool
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
from zipfile import ZipFile
|
||||
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
MAX_SECS = 10
|
||||
MIN_SECS = 1
|
||||
@ -37,7 +28,7 @@ ARCHIVE_URL = (
|
||||
|
||||
def _download_and_preprocess_data(target_dir):
|
||||
# Making path absolute
|
||||
target_dir = path.abspath(target_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
# Conditionally download data
|
||||
archive_path = maybe_download(ARCHIVE_NAME, target_dir, ARCHIVE_URL)
|
||||
# Conditionally extract common voice data
|
||||
@ -48,8 +39,8 @@ def _download_and_preprocess_data(target_dir):
|
||||
|
||||
def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
# If target_dir/extracted_data does not exist, extract archive in target_dir
|
||||
extracted_path = path.join(target_dir, extracted_data)
|
||||
if not path.exists(extracted_path):
|
||||
extracted_path = os.path.join(target_dir, extracted_data)
|
||||
if not os.path.exists(extracted_path):
|
||||
print(f"No directory {extracted_path} - extracting archive...")
|
||||
with ZipFile(archive_path, "r") as zipobj:
|
||||
# Extract all the contents of zip file in current directory
|
||||
@ -59,8 +50,8 @@ def _maybe_extract(target_dir, extracted_data, archive_path):
|
||||
|
||||
|
||||
def _maybe_convert_sets(target_dir, extracted_data):
|
||||
extracted_dir = path.join(target_dir, extracted_data, "wav48")
|
||||
txt_dir = path.join(target_dir, extracted_data, "txt")
|
||||
extracted_dir = os.path.join(target_dir, extracted_data, "wav48")
|
||||
txt_dir = os.path.join(target_dir, extracted_data, "txt")
|
||||
|
||||
directory = os.path.expanduser(extracted_dir)
|
||||
srtd = len(sorted(os.listdir(directory)))
|
||||
|
@ -2,23 +2,20 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import codecs
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import tarfile
|
||||
import pandas
|
||||
import re
|
||||
import unicodedata
|
||||
import tarfile
|
||||
import threading
|
||||
from multiprocessing.pool import ThreadPool
|
||||
import unicodedata
|
||||
|
||||
from six.moves import urllib
|
||||
from glob import glob
|
||||
from os import makedirs, path
|
||||
from bs4 import BeautifulSoup
|
||||
from deepspeech_training.util.downloader import maybe_download
|
||||
from glob import glob
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from os import makedirs, path
|
||||
from six.moves import urllib
|
||||
from tensorflow.python.platform import gfile
|
||||
from util.downloader import maybe_download
|
||||
|
||||
"""The number of jobs to run in parallel"""
|
||||
NUM_PARALLEL = 8
|
||||
@ -99,7 +96,7 @@ def _parallel_extracter(data_dir, number_of_test, number_of_dev, total, counter)
|
||||
dataset_dir = path.join(data_dir, "dev")
|
||||
else:
|
||||
dataset_dir = path.join(data_dir, "train")
|
||||
if not gfile.Exists(path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
|
||||
if not gfile.Exists(os.path.join(dataset_dir, '.'.join(filename_of(archive).split(".")[:-1]))):
|
||||
c = counter.increment()
|
||||
print('Extracting file {} ({}/{})...'.format(i+1, c, total))
|
||||
tar = tarfile.open(archive)
|
||||
@ -132,14 +129,14 @@ def _download_and_preprocess_data(data_dir):
|
||||
p.map(downloader, enumerate(refs))
|
||||
|
||||
# Conditionally extract data to dataset_dir
|
||||
if not path.isdir(path.join(data_dir,"test")):
|
||||
makedirs(path.join(data_dir,"test"))
|
||||
if not path.isdir(path.join(data_dir,"dev")):
|
||||
makedirs(path.join(data_dir,"dev"))
|
||||
if not path.isdir(path.join(data_dir,"train")):
|
||||
makedirs(path.join(data_dir,"train"))
|
||||
if not path.isdir(os.path.join(data_dir, "test")):
|
||||
makedirs(os.path.join(data_dir, "test"))
|
||||
if not path.isdir(os.path.join(data_dir, "dev")):
|
||||
makedirs(os.path.join(data_dir, "dev"))
|
||||
if not path.isdir(os.path.join(data_dir, "train")):
|
||||
makedirs(os.path.join(data_dir, "train"))
|
||||
|
||||
tarfiles = glob(path.join(archive_dir, "*.tgz"))
|
||||
tarfiles = glob(os.path.join(archive_dir, "*.tgz"))
|
||||
number_of_files = len(tarfiles)
|
||||
number_of_test = number_of_files//100
|
||||
number_of_dev = number_of_files//100
|
||||
@ -156,20 +153,20 @@ def _download_and_preprocess_data(data_dir):
|
||||
train_files = _generate_dataset(data_dir, "train")
|
||||
|
||||
# Write sets to disk as CSV files
|
||||
train_files.to_csv(path.join(data_dir, "voxforge-train.csv"), index=False)
|
||||
dev_files.to_csv(path.join(data_dir, "voxforge-dev.csv"), index=False)
|
||||
test_files.to_csv(path.join(data_dir, "voxforge-test.csv"), index=False)
|
||||
train_files.to_csv(os.path.join(data_dir, "voxforge-train.csv"), index=False)
|
||||
dev_files.to_csv(os.path.join(data_dir, "voxforge-dev.csv"), index=False)
|
||||
test_files.to_csv(os.path.join(data_dir, "voxforge-test.csv"), index=False)
|
||||
|
||||
def _generate_dataset(data_dir, data_set):
|
||||
extracted_dir = path.join(data_dir, data_set)
|
||||
files = []
|
||||
for promts_file in glob(path.join(extracted_dir+"/*/etc/", "PROMPTS")):
|
||||
if path.isdir(path.join(promts_file[:-11],"wav")):
|
||||
for promts_file in glob(os.path.join(extracted_dir+"/*/etc/", "PROMPTS")):
|
||||
if path.isdir(os.path.join(promts_file[:-11], "wav")):
|
||||
with codecs.open(promts_file, 'r', 'utf-8') as f:
|
||||
for line in f:
|
||||
id = line.split(' ')[0].split('/')[-1]
|
||||
sentence = ' '.join(line.split(' ')[1:])
|
||||
sentence = re.sub("[^a-z']"," ",sentence.strip().lower())
|
||||
sentence = re.sub("[^a-z']", " ",sentence.strip().lower())
|
||||
transcript = ""
|
||||
for token in sentence.split(" "):
|
||||
word = token.strip()
|
||||
@ -178,14 +175,14 @@ def _generate_dataset(data_dir, data_set):
|
||||
transcript = unicodedata.normalize("NFKD", transcript.strip()) \
|
||||
.encode("ascii", "ignore") \
|
||||
.decode("ascii", "ignore")
|
||||
wav_file = path.join(promts_file[:-11],"wav/" + id + ".wav")
|
||||
wav_file = path.join(promts_file[:-11], "wav/" + id + ".wav")
|
||||
if gfile.Exists(wav_file):
|
||||
wav_filesize = path.getsize(wav_file)
|
||||
# remove audios that are shorter than 0.5s and longer than 20s.
|
||||
# remove audios that are too short for transcript.
|
||||
if (wav_filesize/32000)>0.5 and (wav_filesize/32000)<20 and transcript!="" and \
|
||||
wav_filesize/len(transcript)>1400:
|
||||
files.append((path.abspath(wav_file), wav_filesize, transcript))
|
||||
if ((wav_filesize/32000) > 0.5 and (wav_filesize/32000) < 20 and transcript != "" and
|
||||
wav_filesize/len(transcript) > 1400):
|
||||
files.append((os.path.abspath(wav_file), wav_filesize, transcript))
|
||||
|
||||
return pandas.DataFrame(data=files, columns=["wav_filename", "wav_filesize", "transcript"])
|
||||
|
||||
|
13
bin/play.py
13
bin/play.py
@ -5,17 +5,12 @@ Use "python3 build_sdb.py -h" for help
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
import random
|
||||
import argparse
|
||||
import random
|
||||
import sys
|
||||
|
||||
from util.sample_collections import samples_from_file, LabeledSample
|
||||
from util.audio import AUDIO_TYPE_PCM
|
||||
from deepspeech_training.util.audio import AUDIO_TYPE_PCM
|
||||
from deepspeech_training.util.sample_collections import samples_from_file, LabeledSample
|
||||
|
||||
|
||||
def play_sample(samples, index):
|
||||
|
@ -1,17 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from util.text import Alphabet, UTF8Alphabet
|
||||
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
|
||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||
|
||||
|
||||
|
@ -25,7 +25,7 @@ In creating a virtual environment you will create a directory containing a ``pyt
|
||||
|
||||
.. code-block::
|
||||
|
||||
$ virtualenv -p python3 $HOME/tmp/deepspeech-train-venv/
|
||||
$ python3 -m venv $HOME/tmp/deepspeech-train-venv/
|
||||
|
||||
Once this command completes successfully, the environment will be ready to be activated.
|
||||
|
||||
@ -46,7 +46,7 @@ Install the required dependencies using ``pip3``\ :
|
||||
.. code-block:: bash
|
||||
|
||||
cd DeepSpeech
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install -e .
|
||||
|
||||
The ``webrtcvad`` Python package might require you to ensure you have proper tooling to build Python modules:
|
||||
|
||||
@ -70,7 +70,7 @@ If you have a capable (NVIDIA, at least 8GB of VRAM) GPU, it is highly recommend
|
||||
.. code-block:: bash
|
||||
|
||||
pip3 uninstall tensorflow
|
||||
pip3 install 'tensorflow-gpu==1.15.0'
|
||||
pip3 install 'tensorflow-gpu==1.15.2'
|
||||
|
||||
Please ensure you have the required `CUDA dependency <USING.rst#cuda-dependency>`_.
|
||||
|
||||
|
158
evaluate.py
Executable file → Normal file
158
evaluate.py
Executable file → Normal file
@ -2,155 +2,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import absl.app
|
||||
import progressbar
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from six.moves import zip
|
||||
|
||||
from util.config import Config, initialize_globals
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from util.evaluate_tools import calculate_and_print_report
|
||||
from util.feeding import create_dataset
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.helpers import check_ctcdecoder_version
|
||||
from util.logging import create_progressbar, log_error, log_progress
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
def sparse_tensor_value_to_texts(value, alphabet):
|
||||
r"""
|
||||
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
||||
representing its values, converting tokens to strings using ``alphabet``.
|
||||
"""
|
||||
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
||||
|
||||
|
||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||
indices = sp_tuple[0]
|
||||
values = sp_tuple[1]
|
||||
results = [[] for _ in range(sp_tuple[2][0])]
|
||||
for i, index in enumerate(indices):
|
||||
results[index[0]].append(values[i])
|
||||
# List of strings
|
||||
return [alphabet.decode(res) for res in results]
|
||||
|
||||
|
||||
def evaluate(test_csvs, create_model):
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
|
||||
test_csvs = FLAGS.test_files.split(',')
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||
tfv1.data.get_output_shapes(test_sets[0]),
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||
|
||||
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x,
|
||||
batch_size=FLAGS.test_batch_size,
|
||||
seq_length=batch_x_len,
|
||||
dropout=no_dropout)
|
||||
|
||||
# Transpose to batch major and apply softmax for decoder
|
||||
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
|
||||
|
||||
loss = tfv1.nn.ctc_loss(labels=batch_y,
|
||||
inputs=logits,
|
||||
sequence_length=batch_x_len)
|
||||
|
||||
tfv1.train.get_or_create_global_step()
|
||||
|
||||
# Get number of accessible CPU cores for this process
|
||||
try:
|
||||
num_processes = cpu_count()
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
losses = []
|
||||
predictions = []
|
||||
ground_truths = []
|
||||
|
||||
bar = create_progressbar(prefix='Test epoch | ',
|
||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||
log_progress('Test epoch...')
|
||||
|
||||
step_count = 0
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# First pass, compute losses and transposed logits for decoding
|
||||
while True:
|
||||
try:
|
||||
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||
num_processes=num_processes, scorer=scorer,
|
||||
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
predictions.extend(d[0][1] for d in decoded)
|
||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||
losses.extend(batch_loss)
|
||||
|
||||
step_count += 1
|
||||
bar.update(step_count)
|
||||
|
||||
bar.finish()
|
||||
|
||||
# Print test summary
|
||||
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||
return test_samples
|
||||
|
||||
samples = []
|
||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||
print('Testing model on {}'.format(csv))
|
||||
samples.extend(run_test(init_op, dataset=csv))
|
||||
return samples
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if not FLAGS.test_files:
|
||||
log_error('You need to specify what files to use for evaluation via '
|
||||
'the --test_files flag.')
|
||||
sys.exit(1)
|
||||
|
||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
try:
|
||||
from deepspeech_training import evaluate as ds_evaluate
|
||||
except ImportError:
|
||||
print('Training package is not installed. See training documentation.')
|
||||
raise
|
||||
|
||||
ds_evaluate.run_script()
|
||||
|
@ -10,13 +10,12 @@ import csv
|
||||
import os
|
||||
import sys
|
||||
|
||||
from functools import partial
|
||||
from six.moves import zip, range
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from deepspeech import Model
|
||||
|
||||
from util.evaluate_tools import calculate_and_print_report
|
||||
from util.flags import create_flags
|
||||
from deepspeech_training.util.evaluate_tools import calculate_and_print_report
|
||||
from deepspeech_training.util.flags import create_flags
|
||||
from functools import partial
|
||||
from multiprocessing import JoinableQueue, Process, cpu_count, Manager
|
||||
from six.moves import zip, range
|
||||
|
||||
r'''
|
||||
This module should be self-contained:
|
||||
|
@ -2,19 +2,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, print_function
|
||||
|
||||
import sys
|
||||
|
||||
import optuna
|
||||
import absl.app
|
||||
from ds_ctcdecoder import Scorer
|
||||
import optuna
|
||||
import sys
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from DeepSpeech import create_model
|
||||
from evaluate import evaluate
|
||||
from util.config import Config, initialize_globals
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error
|
||||
from util.evaluate_tools import wer_cer_batch
|
||||
from deepspeech_training.evaluate import evaluate
|
||||
from deepspeech_training.train import create_model
|
||||
from deepspeech_training.util.config import Config, initialize_globals
|
||||
from deepspeech_training.util.flags import create_flags, FLAGS
|
||||
from deepspeech_training.util.logging import log_error
|
||||
from deepspeech_training.util.evaluate_tools import wer_cer_batch
|
||||
from ds_ctcdecoder import Scorer
|
||||
|
||||
|
||||
def character_based():
|
||||
|
@ -1,24 +0,0 @@
|
||||
# Main training requirements
|
||||
tensorflow == 1.15.2
|
||||
numpy == 1.18.1
|
||||
progressbar2
|
||||
six
|
||||
pyxdg
|
||||
attrdict
|
||||
absl-py
|
||||
semver
|
||||
opuslib == 2.0.0
|
||||
|
||||
# Requirements for building native_client files
|
||||
setuptools
|
||||
|
||||
# Requirements for importers
|
||||
sox
|
||||
bs4
|
||||
pandas
|
||||
requests
|
||||
librosa
|
||||
soundfile
|
||||
|
||||
# Requirements for optimizer
|
||||
optuna
|
53
setup.py
Normal file
53
setup.py
Normal file
@ -0,0 +1,53 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
def main():
|
||||
setup(
|
||||
name='deepspeech_training',
|
||||
version='0.0.1',
|
||||
description='Training code for mozilla DeepSpeech',
|
||||
url='https://github.com/mozilla/DeepSpeech',
|
||||
author='Mozilla',
|
||||
license='MPL-2.0',
|
||||
# Classifiers help users find your project by categorizing it.
|
||||
#
|
||||
# For a list of valid classifiers, see https://pypi.org/classifiers/
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Developers',
|
||||
'Topic :: Multimedia :: Sound/Audio :: Speech',
|
||||
'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
|
||||
'Programming Language :: Python :: 3',
|
||||
],
|
||||
package_dir={'': 'training'},
|
||||
packages=find_packages(where='training'),
|
||||
python_requires='>=3.5, <4',
|
||||
install_requires=[
|
||||
'tensorflow == 1.15.2',
|
||||
'numpy == 1.18.1',
|
||||
'progressbar2',
|
||||
'six',
|
||||
'pyxdg',
|
||||
'attrdict',
|
||||
'absl-py',
|
||||
'semver',
|
||||
'opuslib == 2.0.0',
|
||||
'optuna',
|
||||
'sox',
|
||||
'bs4',
|
||||
'pandas',
|
||||
'requests',
|
||||
'librosa',
|
||||
'soundfile',
|
||||
],
|
||||
# If there are data files included in your packages that need to be
|
||||
# installed, specify them here.
|
||||
package_data={
|
||||
'deepspeech_training': [
|
||||
'VERSION',
|
||||
'GRAPH_VERSION',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
43
stats.py
43
stats.py
@ -1,10 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import functools
|
||||
import pandas
|
||||
|
||||
from deepspeech_training.util.helpers import secs_to_hours
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_csvs(csv_files):
|
||||
# Relative paths are relative to CSV location
|
||||
def absolutify(csv, path):
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
return str(path)
|
||||
return str(csv.parent / path)
|
||||
|
||||
sets = []
|
||||
for csv in csv_files:
|
||||
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
|
||||
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv))
|
||||
sets.append(file)
|
||||
|
||||
# Concat all sets, drop any extra columns, re-index the final result as 0..N
|
||||
return pandas.concat(sets, join='inner', ignore_index=True)
|
||||
|
||||
from util.helpers import secs_to_hours
|
||||
from util.feeding import read_csvs
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -14,20 +33,16 @@ def main():
|
||||
parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels")
|
||||
parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample")
|
||||
args = parser.parse_args()
|
||||
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
|
||||
in_files = [Path(i).absolute() for i in args.csv_files.split(",")]
|
||||
|
||||
csv_dataframe = read_csvs(in_files)
|
||||
total_bytes = csv_dataframe['wav_filesize'].sum()
|
||||
total_files = len(csv_dataframe.index)
|
||||
total_files = len(csv_dataframe)
|
||||
total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum()
|
||||
|
||||
bytes_without_headers = total_bytes - 44 * total_files
|
||||
|
||||
total_time = bytes_without_headers / (args.sample_rate * args.channels * args.bits_per_sample / 8)
|
||||
|
||||
print('total_bytes', total_bytes)
|
||||
print('total_files', total_files)
|
||||
print('bytes_without_headers', bytes_without_headers)
|
||||
print('total_time', secs_to_hours(total_time))
|
||||
print('Total bytes:', total_bytes)
|
||||
print('Total files:', total_files)
|
||||
print('Total time:', secs_to_hours(total_seconds))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -17,7 +17,9 @@ deepspeech_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type})
|
||||
set -o pipefail
|
||||
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: --upgrade ${deepspeech_pkg_url} | cat
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
which deepspeech
|
||||
|
@ -17,7 +17,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
decoder_pkg_url=$(get_python_pkg_url ${pyver_pkg} ${py_unicode_type} "ds_ctcdecoder" "${DECODER_ARTIFACTS_ROOT}")
|
||||
|
@ -16,7 +16,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
|
@ -14,7 +14,9 @@ virtualenv_activate "${pyalias}" "deepspeech"
|
||||
|
||||
set -o pipefail
|
||||
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
|
||||
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
|
||||
pushd ${HOME}/DeepSpeech/ds
|
||||
pip install --upgrade . | cat
|
||||
popd
|
||||
set +o pipefail
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
|
@ -1,10 +1,14 @@
|
||||
import unittest
|
||||
|
||||
from argparse import Namespace
|
||||
from .importers import validate_label_eng, get_validate_label
|
||||
from deepspeech_training.util.importers import validate_label_eng, get_validate_label
|
||||
from pathlib import Path
|
||||
|
||||
def from_here(path):
|
||||
here = Path(__file__)
|
||||
return here.parent / path
|
||||
|
||||
class TestValidateLabelEng(unittest.TestCase):
|
||||
|
||||
def test_numbers(self):
|
||||
label = validate_label_eng("this is a 1 2 3 test")
|
||||
self.assertEqual(label, None)
|
||||
@ -24,12 +28,12 @@ class TestGetValidateLabel(unittest.TestCase):
|
||||
self.assertEqual(f('toto1234[{[{[]'), None)
|
||||
|
||||
def test_get_validate_label_missing(self):
|
||||
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
|
||||
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py'))
|
||||
f = get_validate_label(args)
|
||||
self.assertEqual(f, None)
|
||||
|
||||
def test_get_validate_label(self):
|
||||
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
|
||||
args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py'))
|
||||
f = get_validate_label(args)
|
||||
l = f('toto')
|
||||
self.assertEqual(l, 'toto')
|
@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
import os
|
||||
|
||||
from .text import Alphabet
|
||||
from deepspeech_training.util.text import Alphabet
|
||||
|
||||
class TestAlphabetParsing(unittest.TestCase):
|
||||
|
1
training/deepspeech_training/GRAPH_VERSION
Symbolic link
1
training/deepspeech_training/GRAPH_VERSION
Symbolic link
@ -0,0 +1 @@
|
||||
../../GRAPH_VERSION
|
1
training/deepspeech_training/VERSION
Symbolic link
1
training/deepspeech_training/VERSION
Symbolic link
@ -0,0 +1 @@
|
||||
../../VERSION
|
0
training/deepspeech_training/__init__.py
Normal file
0
training/deepspeech_training/__init__.py
Normal file
159
training/deepspeech_training/evaluate.py
Executable file
159
training/deepspeech_training/evaluate.py
Executable file
@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
import absl.app
|
||||
import progressbar
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from six.moves import zip
|
||||
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph
|
||||
from .util.evaluate_tools import calculate_and_print_report
|
||||
from .util.feeding import create_dataset
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
from .util.logging import create_progressbar, log_error, log_progress
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
def sparse_tensor_value_to_texts(value, alphabet):
|
||||
r"""
|
||||
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
||||
representing its values, converting tokens to strings using ``alphabet``.
|
||||
"""
|
||||
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
||||
|
||||
|
||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||
indices = sp_tuple[0]
|
||||
values = sp_tuple[1]
|
||||
results = [[] for _ in range(sp_tuple[2][0])]
|
||||
for i, index in enumerate(indices):
|
||||
results[index[0]].append(values[i])
|
||||
# List of strings
|
||||
return [alphabet.decode(res) for res in results]
|
||||
|
||||
|
||||
def evaluate(test_csvs, create_model):
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
|
||||
test_csvs = FLAGS.test_files.split(',')
|
||||
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
|
||||
tfv1.data.get_output_shapes(test_sets[0]),
|
||||
output_classes=tfv1.data.get_output_classes(test_sets[0]))
|
||||
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
|
||||
|
||||
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x,
|
||||
batch_size=FLAGS.test_batch_size,
|
||||
seq_length=batch_x_len,
|
||||
dropout=no_dropout)
|
||||
|
||||
# Transpose to batch major and apply softmax for decoder
|
||||
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
|
||||
|
||||
loss = tfv1.nn.ctc_loss(labels=batch_y,
|
||||
inputs=logits,
|
||||
sequence_length=batch_x_len)
|
||||
|
||||
tfv1.train.get_or_create_global_step()
|
||||
|
||||
# Get number of accessible CPU cores for this process
|
||||
try:
|
||||
num_processes = cpu_count()
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_test(init_op, dataset):
|
||||
wav_filenames = []
|
||||
losses = []
|
||||
predictions = []
|
||||
ground_truths = []
|
||||
|
||||
bar = create_progressbar(prefix='Test epoch | ',
|
||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||
log_progress('Test epoch...')
|
||||
|
||||
step_count = 0
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# First pass, compute losses and transposed logits for decoding
|
||||
while True:
|
||||
try:
|
||||
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \
|
||||
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y])
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
|
||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||
num_processes=num_processes, scorer=scorer,
|
||||
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
predictions.extend(d[0][1] for d in decoded)
|
||||
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet))
|
||||
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames)
|
||||
losses.extend(batch_loss)
|
||||
|
||||
step_count += 1
|
||||
bar.update(step_count)
|
||||
|
||||
bar.finish()
|
||||
|
||||
# Print test summary
|
||||
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset)
|
||||
return test_samples
|
||||
|
||||
samples = []
|
||||
for csv, init_op in zip(test_csvs, test_init_ops):
|
||||
print('Testing model on {}'.format(csv))
|
||||
samples.extend(run_test(init_op, dataset=csv))
|
||||
return samples
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if not FLAGS.test_files:
|
||||
log_error('You need to specify what files to use for evaluation via '
|
||||
'the --test_files flag.')
|
||||
sys.exit(1)
|
||||
|
||||
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
def run_script():
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_script()
|
936
training/deepspeech_training/train.py
Normal file
936
training/deepspeech_training/train.py
Normal file
@ -0,0 +1,936 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||
|
||||
import absl.app
|
||||
import json
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import shutil
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
import time
|
||||
|
||||
tfv1.logging.set_verbosity({
|
||||
'0': tfv1.logging.DEBUG,
|
||||
'1': tfv1.logging.INFO,
|
||||
'2': tfv1.logging.WARN,
|
||||
'3': tfv1.logging.ERROR
|
||||
}.get(DESIRED_LOG_LEVEL))
|
||||
|
||||
from datetime import datetime
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||
from .evaluate import evaluate
|
||||
from six.moves import zip, range
|
||||
from .util.config import Config, initialize_globals
|
||||
from .util.checkpoints import load_or_init_graph
|
||||
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from .util.flags import create_flags, FLAGS
|
||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||
from .util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
# Graph Creation
|
||||
# ==============
|
||||
|
||||
def variable_on_cpu(name, shape, initializer):
|
||||
r"""
|
||||
Next we concern ourselves with graph creation.
|
||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||
used to create a variable in CPU memory.
|
||||
"""
|
||||
# Use the /cpu:0 device for scoped operations
|
||||
with tf.device(Config.cpu_device):
|
||||
# Create or get apropos variable
|
||||
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||
return var
|
||||
|
||||
|
||||
def create_overlapping_windows(batch_x):
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
window_width = 2 * Config.n_context + 1
|
||||
num_channels = Config.n_input
|
||||
|
||||
# Create a constant convolution filter using an identity matrix, so that the
|
||||
# convolution returns patches of the input tensor as is, and we can create
|
||||
# overlapping windows over the MFCCs.
|
||||
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
||||
|
||||
# Create overlapping windows
|
||||
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
||||
|
||||
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||
|
||||
return batch_x
|
||||
|
||||
|
||||
def dense(name, x, units, dropout_rate=None, relu=True):
|
||||
with tfv1.variable_scope(name):
|
||||
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||
|
||||
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||
|
||||
if relu:
|
||||
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
||||
|
||||
if dropout_rate is not None:
|
||||
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
output, output_state = fw_cell(inputs=x,
|
||||
dtype=tf.float32,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
||||
|
||||
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||
# the object it creates the variables, and then you just call it several times
|
||||
# to enable variable re-use. Because all of our code is structure in an old
|
||||
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||
# emulating a static function variable.
|
||||
if not rnn_impl_cudnn_rnn.cell:
|
||||
# Forward direction cell:
|
||||
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
||||
num_units=Config.n_cell_dim,
|
||||
input_mode='linear_input',
|
||||
direction='unidirectional',
|
||||
dtype=tf.float32)
|
||||
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||
|
||||
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
||||
sequence_lengths=seq_length)
|
||||
|
||||
return output, output_state
|
||||
|
||||
rnn_impl_cudnn_rnn.cell = None
|
||||
|
||||
|
||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
||||
# Forward direction cell:
|
||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name='cudnn_compatible_lstm_cell')
|
||||
|
||||
# Split rank N tensor into list of rank N-1 tensors
|
||||
x = [x[l] for l in range(x.shape[0])]
|
||||
|
||||
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
|
||||
inputs=x,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state,
|
||||
dtype=tf.float32,
|
||||
scope='cell_0')
|
||||
|
||||
output = tf.concat(output, 0)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
|
||||
layers = {}
|
||||
|
||||
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
if not batch_size:
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
|
||||
# Create overlapping feature windows if needed
|
||||
if overlap:
|
||||
batch_x = create_overlapping_windows(batch_x)
|
||||
|
||||
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
||||
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
||||
|
||||
# Permute n_steps and batch_size
|
||||
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
||||
# Reshape to prepare input for first layer
|
||||
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
||||
layers['input_reshaped'] = batch_x
|
||||
|
||||
# The next three blocks will pass `batch_x` through three hidden layers with
|
||||
# clipped RELU activation and dropout.
|
||||
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0])
|
||||
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1])
|
||||
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2])
|
||||
|
||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
||||
|
||||
# Run through parametrized RNN implementation, as we use different RNNs
|
||||
# for training and inference
|
||||
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
||||
|
||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
||||
layers['rnn_output'] = output
|
||||
layers['rnn_output_state'] = output_state
|
||||
|
||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
||||
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5])
|
||||
|
||||
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
||||
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
|
||||
|
||||
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
||||
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
||||
# Note, that this differs from the input in that it is time-major.
|
||||
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
|
||||
layers['raw_logits'] = layer_6
|
||||
|
||||
# Output shape: [n_steps, batch_size, n_hidden_6]
|
||||
return layer_6, layers
|
||||
|
||||
|
||||
# Accuracy and Loss
|
||||
# =================
|
||||
|
||||
# In accord with 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# the loss function used by our network should be the CTC loss function
|
||||
# (http://www.cs.toronto.edu/~graves/preprint.pdf).
|
||||
# Conveniently, this loss function is implemented in TensorFlow.
|
||||
# Thus, we can simply make use of this implementation to define our loss.
|
||||
|
||||
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
r'''
|
||||
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
|
||||
Next to total and average loss it returns the mean edit distance,
|
||||
the decoded result and the batch's original Y.
|
||||
'''
|
||||
# Obtain the next batch of data
|
||||
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
|
||||
|
||||
if FLAGS.train_cudnn:
|
||||
rnn_impl = rnn_impl_cudnn_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
# Calculate the logits of the batch
|
||||
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)
|
||||
|
||||
# Compute the CTC loss using TensorFlow's `ctc_loss`
|
||||
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
|
||||
|
||||
# Check if any files lead to non finite loss
|
||||
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
|
||||
|
||||
# Calculate the average loss across the batch
|
||||
avg_loss = tf.reduce_mean(input_tensor=total_loss)
|
||||
|
||||
# Finally we return the average loss
|
||||
return avg_loss, non_finite_files
|
||||
|
||||
|
||||
# Adam Optimization
|
||||
# =================
|
||||
|
||||
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
|
||||
# (http://arxiv.org/abs/1412.5567),
|
||||
# in which 'Nesterov's Accelerated Gradient Descent'
|
||||
# (www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
|
||||
# we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980),
|
||||
# because, generally, it requires less fine-tuning.
|
||||
def create_optimizer(learning_rate_var):
|
||||
optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var,
|
||||
beta1=FLAGS.beta1,
|
||||
beta2=FLAGS.beta2,
|
||||
epsilon=FLAGS.epsilon)
|
||||
return optimizer
|
||||
|
||||
|
||||
# Towers
|
||||
# ======
|
||||
|
||||
# In order to properly make use of multiple GPU's, one must introduce new abstractions,
|
||||
# not present when using a single GPU, that facilitate the multi-GPU use case.
|
||||
# In particular, one must introduce a means to isolate the inference and gradient
|
||||
# calculations on the various GPU's.
|
||||
# The abstraction we intoduce for this purpose is called a 'tower'.
|
||||
# A tower is specified by two properties:
|
||||
# * **Scope** - A scope, as provided by `tf.name_scope()`,
|
||||
# is a means to isolate the operations within a tower.
|
||||
# For example, all operations within 'tower 0' could have their name prefixed with `tower_0/`.
|
||||
# * **Device** - A hardware device, as provided by `tf.device()`,
|
||||
# on which all operations within the tower execute.
|
||||
# For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`.
|
||||
|
||||
def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
r'''
|
||||
With this preliminary step out of the way, we can for each GPU introduce a
|
||||
tower for which's batch we calculate and return the optimization gradients
|
||||
and the average loss across towers.
|
||||
'''
|
||||
# To calculate the mean of the losses
|
||||
tower_avg_losses = []
|
||||
|
||||
# Tower gradients to return
|
||||
tower_gradients = []
|
||||
|
||||
# Aggregate any non finite files in the batches
|
||||
tower_non_finite_files = []
|
||||
|
||||
with tfv1.variable_scope(tfv1.get_variable_scope()):
|
||||
# Loop over available_devices
|
||||
for i in range(len(Config.available_devices)):
|
||||
# Execute operations of tower i on device i
|
||||
device = Config.available_devices[i]
|
||||
with tf.device(device):
|
||||
# Create a scope for all operations of tower i
|
||||
with tf.name_scope('tower_%d' % i):
|
||||
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
|
||||
# batch along with the original batch's labels (Y) of this tower
|
||||
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
|
||||
# Allow for variables to be re-used by the next tower
|
||||
tfv1.get_variable_scope().reuse_variables()
|
||||
|
||||
# Retain tower's avg losses
|
||||
tower_avg_losses.append(avg_loss)
|
||||
|
||||
# Compute gradients for model parameters using tower's mini-batch
|
||||
gradients = optimizer.compute_gradients(avg_loss)
|
||||
|
||||
# Retain tower's gradients
|
||||
tower_gradients.append(gradients)
|
||||
|
||||
tower_non_finite_files.append(non_finite_files)
|
||||
|
||||
avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
|
||||
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
|
||||
|
||||
all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
|
||||
|
||||
# Return gradients and the average loss
|
||||
return tower_gradients, avg_loss_across_towers, all_non_finite_files
|
||||
|
||||
|
||||
def average_gradients(tower_gradients):
|
||||
r'''
|
||||
A routine for computing each variable's average of the gradients obtained from the GPUs.
|
||||
Note also that this code acts as a synchronization point as it requires all
|
||||
GPUs to be finished with their mini-batch before it can run to completion.
|
||||
'''
|
||||
# List of average gradients to return to the caller
|
||||
average_grads = []
|
||||
|
||||
# Run this on cpu_device to conserve GPU memory
|
||||
with tf.device(Config.cpu_device):
|
||||
# Loop over gradient/variable pairs from all towers
|
||||
for grad_and_vars in zip(*tower_gradients):
|
||||
# Introduce grads to store the gradients for the current variable
|
||||
grads = []
|
||||
|
||||
# Loop over the gradients for the current variable
|
||||
for g, _ in grad_and_vars:
|
||||
# Add 0 dimension to the gradients to represent the tower.
|
||||
expanded_g = tf.expand_dims(g, 0)
|
||||
# Append on a 'tower' dimension which we will average over below.
|
||||
grads.append(expanded_g)
|
||||
|
||||
# Average over the 'tower' dimension
|
||||
grad = tf.concat(grads, 0)
|
||||
grad = tf.reduce_mean(input_tensor=grad, axis=0)
|
||||
|
||||
# Create a gradient/variable tuple for the current variable with its average gradient
|
||||
grad_and_var = (grad, grad_and_vars[0][1])
|
||||
|
||||
# Add the current tuple to average_grads
|
||||
average_grads.append(grad_and_var)
|
||||
|
||||
# Return result to caller
|
||||
return average_grads
|
||||
|
||||
|
||||
|
||||
# Logging
|
||||
# =======
|
||||
|
||||
def log_variable(variable, gradient=None):
|
||||
r'''
|
||||
We introduce a function for logging a tensor variable's current state.
|
||||
It logs scalar values for the mean, standard deviation, minimum and maximum.
|
||||
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
|
||||
'''
|
||||
name = variable.name.replace(':', '_')
|
||||
mean = tf.reduce_mean(input_tensor=variable)
|
||||
tfv1.summary.scalar(name='%s/mean' % name, tensor=mean)
|
||||
tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))))
|
||||
tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable))
|
||||
tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable))
|
||||
tfv1.summary.histogram(name=name, values=variable)
|
||||
if gradient is not None:
|
||||
if isinstance(gradient, tf.IndexedSlices):
|
||||
grad_values = gradient.values
|
||||
else:
|
||||
grad_values = gradient
|
||||
if grad_values is not None:
|
||||
tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values)
|
||||
|
||||
|
||||
def log_grads_and_vars(grads_and_vars):
|
||||
r'''
|
||||
Let's also introduce a helper function for logging collections of gradient/variable tuples.
|
||||
'''
|
||||
for gradient, variable in grads_and_vars:
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def train():
|
||||
do_cache_dataset = True
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (FLAGS.data_aug_features_multiplicative > 0 or
|
||||
FLAGS.data_aug_features_additive > 0 or
|
||||
FLAGS.augmentation_spec_dropout_keeprate < 1 or
|
||||
FLAGS.augmentation_freq_and_time_masking or
|
||||
FLAGS.augmentation_pitch_and_tempo_scaling or
|
||||
FLAGS.augmentation_speed_up_std > 0 or
|
||||
FLAGS.augmentation_sparse_warp):
|
||||
do_cache_dataset = False
|
||||
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(FLAGS.train_files.split(','),
|
||||
batch_size=FLAGS.train_batch_size,
|
||||
enable_cache=FLAGS.feature_cache and do_cache_dataset,
|
||||
cache_path=FLAGS.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer)
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
output_classes=tfv1.data.get_output_classes(train_set))
|
||||
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
dev_sources = FLAGS.dev_files.split(',')
|
||||
dev_sets = [create_dataset([source],
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
train_phase=False,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
|
||||
buffering=FLAGS.read_buffer) for source in dev_sources]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
# Dropout
|
||||
dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)]
|
||||
dropout_feed_dict = {
|
||||
dropout_rates[0]: FLAGS.dropout_rate,
|
||||
dropout_rates[1]: FLAGS.dropout_rate2,
|
||||
dropout_rates[2]: FLAGS.dropout_rate3,
|
||||
dropout_rates[3]: FLAGS.dropout_rate4,
|
||||
dropout_rates[4]: FLAGS.dropout_rate5,
|
||||
dropout_rates[5]: FLAGS.dropout_rate6,
|
||||
}
|
||||
no_dropout_feed_dict = {
|
||||
rate: 0. for rate in dropout_rates
|
||||
}
|
||||
|
||||
# Building the graph
|
||||
learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
|
||||
reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
|
||||
optimizer = create_optimizer(learning_rate_var)
|
||||
|
||||
# Enable mixed precision training
|
||||
if FLAGS.automatic_mixed_precision:
|
||||
log_info('Enabling automatic mixed precision training.')
|
||||
optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
|
||||
|
||||
gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tfv1.train.get_or_create_global_step()
|
||||
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
|
||||
|
||||
# Summaries
|
||||
step_summaries_op = tfv1.summary.merge_all('step_summaries')
|
||||
step_summary_writers = {
|
||||
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
|
||||
'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120)
|
||||
}
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
|
||||
checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
|
||||
|
||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||
|
||||
# Save flags next to checkpoints
|
||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||
with open(flags_file, 'w') as fout:
|
||||
fout.write(FLAGS.flags_into_string())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
# Load checkpoint or initialize variables
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last', 'init']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
def run_set(set_name, epoch, init_op, dataset=None):
|
||||
is_train = set_name == 'train'
|
||||
train_op = apply_gradient_op if is_train else []
|
||||
feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict
|
||||
|
||||
total_loss = 0.0
|
||||
step_count = 0
|
||||
|
||||
step_summary_writer = step_summary_writers.get(set_name)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
# Setup progress bar
|
||||
class LossWidget(progressbar.widgets.FormatLabel):
|
||||
def __init__(self):
|
||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||
|
||||
def __call__(self, progress, data, **kwargs):
|
||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||
|
||||
prefix = 'Epoch {} | {:>10}'.format(epoch, 'Training' if is_train else 'Validation')
|
||||
widgets = [' | ', progressbar.widgets.Timer(),
|
||||
' | Steps: ', progressbar.widgets.Counter(),
|
||||
' | ', LossWidget()]
|
||||
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
||||
# Batch loop
|
||||
while True:
|
||||
try:
|
||||
_, current_step, batch_loss, problem_files, step_summary = \
|
||||
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
|
||||
feed_dict=feed_dict)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.InvalidArgumentError as err:
|
||||
if FLAGS.augmentation_sparse_warp:
|
||||
log_info("Ignoring sparse warp error: {}".format(err))
|
||||
continue
|
||||
raise
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
problem_files = [f.decode('utf8') for f in problem_files[..., 0]]
|
||||
log_error('The following files caused an infinite (or NaN) '
|
||||
'loss: {}'.format(','.join(problem_files)))
|
||||
|
||||
total_loss += batch_loss
|
||||
step_count += 1
|
||||
|
||||
pbar.update(step_count)
|
||||
|
||||
step_summary_writer.add_summary(step_summary, current_step)
|
||||
|
||||
if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
|
||||
checkpoint_time = time.time()
|
||||
|
||||
pbar.finish()
|
||||
mean_loss = total_loss / step_count if step_count > 0 else 0.0
|
||||
return mean_loss, step_count
|
||||
|
||||
log_info('STARTING Optimization')
|
||||
train_start_time = datetime.utcnow()
|
||||
best_dev_loss = float('inf')
|
||||
dev_losses = []
|
||||
epochs_without_improvement = 0
|
||||
try:
|
||||
for epoch in range(FLAGS.epochs):
|
||||
# Training
|
||||
log_progress('Training epoch %d...' % epoch)
|
||||
train_loss, _ = run_set('train', epoch, train_init_op)
|
||||
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||
|
||||
if FLAGS.dev_files:
|
||||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
log_progress('Validating epoch %d on %s...' % (epoch, source))
|
||||
set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
total_steps += steps
|
||||
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
|
||||
|
||||
dev_loss = dev_loss / total_steps
|
||||
dev_losses.append(dev_loss)
|
||||
|
||||
# Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
|
||||
# the improvement has to be greater than FLAGS.es_min_delta
|
||||
if dev_loss > best_dev_loss - FLAGS.es_min_delta:
|
||||
epochs_without_improvement += 1
|
||||
else:
|
||||
epochs_without_improvement = 0
|
||||
|
||||
# Save new best model
|
||||
if dev_loss < best_dev_loss:
|
||||
best_dev_loss = dev_loss
|
||||
save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
|
||||
log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
|
||||
|
||||
# Early stopping
|
||||
if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
|
||||
log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
|
||||
epochs_without_improvement))
|
||||
break
|
||||
|
||||
# Reduce learning rate on plateau
|
||||
if (FLAGS.reduce_lr_on_plateau and
|
||||
epochs_without_improvement % FLAGS.plateau_epochs == 0 and epochs_without_improvement > 0):
|
||||
# If the learning rate was reduced and there is still no improvement
|
||||
# wait FLAGS.plateau_epochs before the learning rate is reduced again
|
||||
session.run(reduce_learning_rate_op)
|
||||
current_learning_rate = learning_rate_var.eval()
|
||||
log_info('Encountered a plateau, reducing learning rate to {}'.format(
|
||||
current_learning_rate))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
|
||||
log_debug('Session closed.')
|
||||
|
||||
|
||||
def test():
|
||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||
if FLAGS.test_output_file:
|
||||
# Save decoded tuples as JSON, converting NumPy floats to Python floats
|
||||
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
|
||||
mfccs = tf.identity(mfccs, name='mfccs')
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in DS_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
|
||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
|
||||
if tflite:
|
||||
rnn_impl = rnn_impl_static_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
logits, layers = create_model(batch_x=input_tensor,
|
||||
batch_size=batch_size,
|
||||
seq_length=seq_length if not FLAGS.export_tflite else None,
|
||||
dropout=no_dropout,
|
||||
previous_state=previous_state,
|
||||
overlap=False,
|
||||
rnn_impl=rnn_impl)
|
||||
|
||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits, name='logits')
|
||||
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
||||
if n_steps > 0:
|
||||
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': logits,
|
||||
},
|
||||
layers
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||
|
||||
inputs = {
|
||||
'input': input_tensor,
|
||||
'previous_state_c': previous_state_c,
|
||||
'previous_state_h': previous_state_h,
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
inputs['input_lengths'] = seq_length
|
||||
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
|
||||
def export():
|
||||
r'''
|
||||
Restores the trained variables into a simpler graph that will be exported for serving.
|
||||
'''
|
||||
log_info('Exporting the model...')
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
|
||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
assert graph_version > 0
|
||||
|
||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||
|
||||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||
|
||||
# Prevent further graph changes
|
||||
tfv1.get_default_graph().finalize()
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
|
||||
output_names = output_names_tensors + output_names_ops
|
||||
|
||||
with tf.Session() as session:
|
||||
# Restore variables from checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
output_filename = FLAGS.export_file_name + '.pb'
|
||||
if FLAGS.remove_export:
|
||||
if os.path.isdir(FLAGS.export_dir):
|
||||
log_info('Removing old export')
|
||||
shutil.rmtree(FLAGS.export_dir)
|
||||
|
||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||
sess=session,
|
||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||
output_node_names=output_names)
|
||||
|
||||
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||
graph_def=frozen_graph,
|
||||
dest_nodes=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||
converter.allow_custom_ops = True
|
||||
tflite_model = converter.convert()
|
||||
|
||||
with open(output_tflite_path, 'wb') as fout:
|
||||
fout.write(tflite_model)
|
||||
|
||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||
|
||||
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
|
||||
FLAGS.export_author_id,
|
||||
FLAGS.export_model_name,
|
||||
FLAGS.export_model_version))
|
||||
|
||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||
with open(metadata_fname, 'w') as f:
|
||||
f.write('---\n')
|
||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
|
||||
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
|
||||
f.write('license: {}\n'.format(FLAGS.export_license))
|
||||
f.write('language: {}\n'.format(FLAGS.export_language))
|
||||
f.write('runtime: {}\n'.format(model_runtime))
|
||||
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
|
||||
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
|
||||
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
|
||||
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
|
||||
f.write('---\n')
|
||||
f.write('{}\n'.format(FLAGS.export_description))
|
||||
|
||||
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
||||
|
||||
|
||||
def package_zip():
|
||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||
zip_filename = os.path.dirname(export_dir)
|
||||
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
log_info('Exported packaged model {}'.format(archive))
|
||||
|
||||
|
||||
def do_single_file_inference(input_file_path):
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
if FLAGS.load == 'auto':
|
||||
method_order = ['best', 'last']
|
||||
else:
|
||||
method_order = [FLAGS.load]
|
||||
load_or_init_graph(session, method_order)
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
||||
|
||||
# Add batch dimension
|
||||
features = tf.expand_dims(features, 0)
|
||||
features_len = tf.expand_dims(features_len, 0)
|
||||
|
||||
# Evaluate
|
||||
features = create_overlapping_windows(features).eval(session=session)
|
||||
features_len = features_len.eval(session=session)
|
||||
|
||||
logits = outputs['outputs'].eval(feed_dict={
|
||||
inputs['input']: features,
|
||||
inputs['input_lengths']: features_len,
|
||||
inputs['previous_state_c']: previous_state_c,
|
||||
inputs['previous_state_h']: previous_state_h,
|
||||
}, session=session)
|
||||
|
||||
logits = np.squeeze(logits)
|
||||
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
||||
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||
# Print highest probability result
|
||||
print(decoded[0][1])
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if FLAGS.train_files:
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test_files:
|
||||
tfv1.reset_default_graph()
|
||||
test()
|
||||
|
||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
export()
|
||||
|
||||
if FLAGS.export_zip:
|
||||
tfv1.reset_default_graph()
|
||||
FLAGS.export_tflite = True
|
||||
|
||||
if os.listdir(FLAGS.export_dir):
|
||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||
sys.exit(1)
|
||||
|
||||
export()
|
||||
package_zip()
|
||||
|
||||
if FLAGS.one_shot_infer:
|
||||
tfv1.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
|
||||
def run_script():
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_script()
|
0
training/deepspeech_training/util/__init__.py
Normal file
0
training/deepspeech_training/util/__init__.py
Normal file
@ -5,7 +5,7 @@ import tempfile
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from util.helpers import LimitingPool
|
||||
from .helpers import LimitingPool
|
||||
|
||||
DEFAULT_RATE = 16000
|
||||
DEFAULT_CHANNELS = 1
|
@ -2,8 +2,8 @@ import sys
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from util.flags import FLAGS
|
||||
from util.logging import log_info, log_error, log_warn
|
||||
from .flags import FLAGS
|
||||
from .logging import log_info, log_error, log_warn
|
||||
|
||||
|
||||
def _load_checkpoint(session, checkpoint_path):
|
@ -8,11 +8,11 @@ import tensorflow.compat.v1 as tfv1
|
||||
from attrdict import AttrDict
|
||||
from xdg import BaseDirectory as xdg
|
||||
|
||||
from util.flags import FLAGS
|
||||
from util.gpu import get_available_gpus
|
||||
from util.logging import log_error
|
||||
from util.text import Alphabet, UTF8Alphabet
|
||||
from util.helpers import parse_file_size
|
||||
from .flags import FLAGS
|
||||
from .gpu import get_available_gpus
|
||||
from .logging import log_error
|
||||
from .text import Alphabet, UTF8Alphabet
|
||||
from .helpers import parse_file_size
|
||||
|
||||
class ConfigSingleton:
|
||||
_config = None
|
@ -7,8 +7,8 @@ import numpy as np
|
||||
|
||||
from attrdict import AttrDict
|
||||
|
||||
from util.flags import FLAGS
|
||||
from util.text import levenshtein
|
||||
from .flags import FLAGS
|
||||
from .text import levenshtein
|
||||
|
||||
|
||||
def pmap(fun, iterable):
|
@ -8,13 +8,13 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
||||
|
||||
from util.config import Config
|
||||
from util.text import text_to_char_array
|
||||
from util.flags import FLAGS
|
||||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from util.audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
|
||||
from util.sample_collections import samples_from_files
|
||||
from util.helpers import remember_exception, MEGABYTE
|
||||
from .config import Config
|
||||
from .text import text_to_char_array
|
||||
from .flags import FLAGS
|
||||
from .spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up, augment_sparse_warp
|
||||
from .audio import change_audio_types, read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT, AUDIO_TYPE_NP
|
||||
from .sample_collections import samples_from_files
|
||||
from .helpers import remember_exception, MEGABYTE
|
||||
|
||||
|
||||
def samples_to_mfccs(samples, sample_rate, train_phase=False, sample_id=None):
|
@ -4,7 +4,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
from util.helpers import secs_to_hours
|
||||
from .helpers import secs_to_hours
|
||||
from collections import Counter
|
||||
|
||||
def get_counter():
|
@ -3,7 +3,7 @@ from __future__ import print_function
|
||||
import progressbar
|
||||
import sys
|
||||
|
||||
from util.flags import FLAGS
|
||||
from .flags import FLAGS
|
||||
|
||||
|
||||
# Logging functions
|
@ -5,8 +5,8 @@ import json
|
||||
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from util.helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||
from util.audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
|
||||
from .helpers import MEGABYTE, GIGABYTE, Interleaved
|
||||
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, SERIALIZABLE_AUDIO_TYPES
|
||||
|
||||
BIG_ENDIAN = 'big'
|
||||
INT_SIZE = 4
|
@ -1,6 +1,7 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from util.sparse_image_warp import sparse_image_warp
|
||||
|
||||
from .sparse_image_warp import sparse_image_warp
|
||||
|
||||
def augment_freq_time_mask(spectrogram,
|
||||
frequency_masking_para=30,
|
167
training/deepspeech_training/util/taskcluster.py
Normal file
167
training/deepspeech_training/util/taskcluster.py
Normal file
@ -0,0 +1,167 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function, absolute_import, division
|
||||
|
||||
import argparse
|
||||
import errno
|
||||
import gzip
|
||||
import os
|
||||
import platform
|
||||
import six.moves.urllib as urllib
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from pkg_resources import parse_version
|
||||
|
||||
|
||||
DEFAULT_SCHEMES = {
|
||||
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
|
||||
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
|
||||
}
|
||||
|
||||
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
|
||||
|
||||
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
|
||||
assert arch_string is not None
|
||||
assert artifact_name is not None
|
||||
assert artifact_name
|
||||
assert branch_name is not None
|
||||
assert branch_name
|
||||
|
||||
return TASKCLUSTER_SCHEME % {'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
|
||||
|
||||
def maybe_download_tc(target_dir, tc_url, progress=True):
|
||||
def report_progress(count, block_size, total_size):
|
||||
percent = (count * block_size * 100) // total_size
|
||||
sys.stdout.write("\rDownloading: %d%%" % percent)
|
||||
sys.stdout.flush()
|
||||
|
||||
if percent >= 100:
|
||||
print('\n')
|
||||
|
||||
assert target_dir is not None
|
||||
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
try:
|
||||
os.makedirs(target_dir)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise e
|
||||
assert os.path.isdir(os.path.dirname(target_dir))
|
||||
|
||||
tc_filename = os.path.basename(tc_url)
|
||||
target_file = os.path.join(target_dir, tc_filename)
|
||||
is_gzip = False
|
||||
if not os.path.isfile(target_file):
|
||||
print('Downloading %s ...' % tc_url)
|
||||
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
|
||||
is_gzip = headers.get('Content-Encoding') == 'gzip'
|
||||
else:
|
||||
print('File already exists: %s' % target_file)
|
||||
|
||||
if is_gzip:
|
||||
with open(target_file, "r+b") as frw:
|
||||
decompressed = gzip.decompress(frw.read())
|
||||
frw.seek(0)
|
||||
frw.write(decompressed)
|
||||
frw.truncate()
|
||||
|
||||
return target_file
|
||||
|
||||
def maybe_download_tc_bin(**kwargs):
|
||||
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
|
||||
final_stat = os.stat(final_file)
|
||||
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
|
||||
parser.add_argument('--target', required=False,
|
||||
help='Where to put the native client binary files')
|
||||
parser.add_argument('--arch', required=False,
|
||||
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
|
||||
parser.add_argument('--artifact', required=False,
|
||||
default='native_client.tar.xz',
|
||||
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
|
||||
parser.add_argument('--source', required=False, default=None,
|
||||
help='Name of the TaskCluster scheme to use.')
|
||||
parser.add_argument('--branch', required=False,
|
||||
help='Branch name to use. Defaulting to current content of VERSION file.')
|
||||
parser.add_argument('--decoder', action='store_true',
|
||||
help='Get URL to ds_ctcdecoder Python package.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.target and not args.decoder:
|
||||
print('Pass either --target or --decoder.')
|
||||
sys.exit(1)
|
||||
|
||||
is_arm = 'arm' in platform.machine()
|
||||
is_mac = 'darwin' in sys.platform
|
||||
is_64bit = sys.maxsize > (2**31 - 1)
|
||||
is_ucs2 = sys.maxunicode < 0x10ffff
|
||||
|
||||
if not args.arch:
|
||||
if is_arm:
|
||||
args.arch = 'arm64' if is_64bit else 'arm'
|
||||
elif is_mac:
|
||||
args.arch = 'osx'
|
||||
else:
|
||||
args.arch = 'cpu'
|
||||
|
||||
if not args.branch:
|
||||
version_string = read('../VERSION').strip()
|
||||
ds_version = parse_version(version_string)
|
||||
args.branch = "v{}".format(version_string)
|
||||
else:
|
||||
ds_version = parse_version(args.branch)
|
||||
|
||||
if args.decoder:
|
||||
plat = platform.system().lower()
|
||||
arch = platform.machine()
|
||||
|
||||
if plat == 'linux' and arch == 'x86_64':
|
||||
plat = 'manylinux1'
|
||||
|
||||
if plat == 'darwin':
|
||||
plat = 'macosx_10_10'
|
||||
|
||||
m_or_mu = 'mu' if is_ucs2 else 'm'
|
||||
pyver = ''.join(map(str, sys.version_info[0:2]))
|
||||
|
||||
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
|
||||
ds_version=ds_version,
|
||||
pyver=pyver,
|
||||
m_or_mu=m_or_mu,
|
||||
platform=plat,
|
||||
arch=arch
|
||||
)
|
||||
|
||||
ctc_arch = args.arch + '-ctc'
|
||||
|
||||
print(get_tc_url(ctc_arch, artifact, args.branch))
|
||||
sys.exit(0)
|
||||
|
||||
if args.source is not None:
|
||||
if args.source in DEFAULT_SCHEMES:
|
||||
global TASKCLUSTER_SCHEME
|
||||
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
|
||||
else:
|
||||
print('No such scheme: %s' % args.source)
|
||||
sys.exit(1)
|
||||
|
||||
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
|
||||
|
||||
if args.artifact == "convert_graphdef_memmapped_format":
|
||||
convert_graph_file = os.path.join(args.target, args.artifact)
|
||||
final_stat = os.stat(convert_graph_file)
|
||||
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
if '.tar.' in args.artifact:
|
||||
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -12,13 +12,13 @@ tflogging.set_verbosity(tflogging.ERROR)
|
||||
import logging
|
||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||
|
||||
from multiprocessing import Process, cpu_count
|
||||
from deepspeech_training.util.audio import AudioFile
|
||||
from deepspeech_training.util.config import Config, initialize_globals
|
||||
from deepspeech_training.util.feeding import split_audio_file
|
||||
from deepspeech_training.util.flags import create_flags, FLAGS
|
||||
from deepspeech_training.util.logging import log_error, log_info, log_progress, create_progressbar
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from util.config import Config, initialize_globals
|
||||
from util.audio import AudioFile
|
||||
from util.feeding import split_audio_file
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error, log_info, log_progress, create_progressbar
|
||||
from multiprocessing import Process, cpu_count
|
||||
|
||||
|
||||
def fail(message, code=1):
|
||||
@ -27,8 +27,8 @@ def fail(message, code=1):
|
||||
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from util.checkpoints import load_or_init_graph
|
||||
from deepspeech_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
from deepspeech_training.util.checkpoints import load_or_init_graph
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
|
||||
try:
|
||||
|
@ -1,168 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function, absolute_import, division
|
||||
|
||||
import argparse
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import errno
|
||||
import stat
|
||||
import gzip
|
||||
|
||||
import six.moves.urllib as urllib
|
||||
|
||||
from pkg_resources import parse_version
|
||||
|
||||
|
||||
DEFAULT_SCHEMES = {
|
||||
'deepspeech': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.deepspeech.native_client.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s',
|
||||
'tensorflow': 'https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.%(branch_name)s.%(arch_string)s/artifacts/public/%(artifact_name)s'
|
||||
}
|
||||
|
||||
TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech'])
|
||||
|
||||
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
|
||||
assert arch_string is not None
|
||||
assert artifact_name is not None
|
||||
assert artifact_name
|
||||
assert branch_name is not None
|
||||
assert branch_name
|
||||
|
||||
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
|
||||
|
||||
def maybe_download_tc(target_dir, tc_url, progress=True):
|
||||
def report_progress(count, block_size, total_size):
|
||||
percent = (count * block_size * 100) // total_size
|
||||
sys.stdout.write("\rDownloading: %d%%" % percent)
|
||||
sys.stdout.flush()
|
||||
|
||||
if percent >= 100:
|
||||
print('\n')
|
||||
|
||||
assert target_dir is not None
|
||||
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
try:
|
||||
os.makedirs(target_dir)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise e
|
||||
assert os.path.isdir(os.path.dirname(target_dir))
|
||||
|
||||
tc_filename = os.path.basename(tc_url)
|
||||
target_file = os.path.join(target_dir, tc_filename)
|
||||
is_gzip = False
|
||||
if not os.path.isfile(target_file):
|
||||
print('Downloading %s ...' % tc_url)
|
||||
_, headers = urllib.request.urlretrieve(tc_url, target_file, reporthook=(report_progress if progress else None))
|
||||
is_gzip = headers.get('Content-Encoding') == 'gzip'
|
||||
else:
|
||||
print('File already exists: %s' % target_file)
|
||||
|
||||
if is_gzip:
|
||||
with open(target_file, "r+b") as frw:
|
||||
decompressed = gzip.decompress(frw.read())
|
||||
frw.seek(0)
|
||||
frw.write(decompressed)
|
||||
frw.truncate()
|
||||
|
||||
return target_file
|
||||
|
||||
def maybe_download_tc_bin(**kwargs):
|
||||
final_file = maybe_download_tc(kwargs['target_dir'], kwargs['tc_url'], kwargs['progress'])
|
||||
final_stat = os.stat(final_file)
|
||||
os.chmod(final_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
|
||||
parser.add_argument('--target', required=False,
|
||||
help='Where to put the native client binary files')
|
||||
parser.add_argument('--arch', required=False,
|
||||
help='Which architecture to download binaries for. "arm" for ARM 7 (32-bit), "arm64" for ARM64, "gpu" for CUDA enabled x86_64 binaries, "cpu" for CPU-only x86_64 binaries, "osx" for CPU-only x86_64 OSX binaries. Optional ("cpu" by default)')
|
||||
parser.add_argument('--artifact', required=False,
|
||||
default='native_client.tar.xz',
|
||||
help='Name of the artifact to download. Defaults to "native_client.tar.xz"')
|
||||
parser.add_argument('--source', required=False, default=None,
|
||||
help='Name of the TaskCluster scheme to use.')
|
||||
parser.add_argument('--branch', required=False,
|
||||
help='Branch name to use. Defaulting to current content of VERSION file.')
|
||||
parser.add_argument('--decoder', action='store_true',
|
||||
help='Get URL to ds_ctcdecoder Python package.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.target and not args.decoder:
|
||||
print('Pass either --target or --decoder.')
|
||||
exit(1)
|
||||
|
||||
is_arm = 'arm' in platform.machine()
|
||||
is_mac = 'darwin' in sys.platform
|
||||
is_64bit = sys.maxsize > (2**31 - 1)
|
||||
is_ucs2 = sys.maxunicode < 0x10ffff
|
||||
|
||||
if not args.arch:
|
||||
if is_arm:
|
||||
args.arch = 'arm64' if is_64bit else 'arm'
|
||||
elif is_mac:
|
||||
args.arch = 'osx'
|
||||
else:
|
||||
args.arch = 'cpu'
|
||||
|
||||
if not args.branch:
|
||||
version_string = read('../VERSION').strip()
|
||||
ds_version = parse_version(version_string)
|
||||
args.branch = "v{}".format(version_string)
|
||||
else:
|
||||
ds_version = parse_version(args.branch)
|
||||
|
||||
if args.decoder:
|
||||
plat = platform.system().lower()
|
||||
arch = platform.machine()
|
||||
|
||||
if plat == 'linux' and arch == 'x86_64':
|
||||
plat = 'manylinux1'
|
||||
|
||||
if plat == 'darwin':
|
||||
plat = 'macosx_10_10'
|
||||
|
||||
m_or_mu = 'mu' if is_ucs2 else 'm'
|
||||
pyver = ''.join(map(str, sys.version_info[0:2]))
|
||||
|
||||
artifact = "ds_ctcdecoder-{ds_version}-cp{pyver}-cp{pyver}{m_or_mu}-{platform}_{arch}.whl".format(
|
||||
ds_version=ds_version,
|
||||
pyver=pyver,
|
||||
m_or_mu=m_or_mu,
|
||||
platform=plat,
|
||||
arch=arch
|
||||
)
|
||||
|
||||
ctc_arch = args.arch + '-ctc'
|
||||
|
||||
print(get_tc_url(ctc_arch, artifact, args.branch))
|
||||
exit(0)
|
||||
|
||||
if args.source is not None:
|
||||
if args.source in DEFAULT_SCHEMES:
|
||||
global TASKCLUSTER_SCHEME
|
||||
TASKCLUSTER_SCHEME = DEFAULT_SCHEMES[args.source]
|
||||
else:
|
||||
print('No such scheme: %s' % args.source)
|
||||
exit(1)
|
||||
|
||||
maybe_download_tc(target_dir=args.target, tc_url=get_tc_url(args.arch, args.artifact, args.branch))
|
||||
|
||||
if args.artifact == "convert_graphdef_memmapped_format":
|
||||
convert_graph_file = os.path.join(args.target, args.artifact)
|
||||
final_stat = os.stat(convert_graph_file)
|
||||
os.chmod(convert_graph_file, final_stat.st_mode | stat.S_IEXEC)
|
||||
|
||||
if '.tar.' in args.artifact:
|
||||
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
try:
|
||||
from deepspeech_training.util import taskcluster as dsu_taskcluster
|
||||
except ImportError:
|
||||
print('Training package is not installed. See training documentation.')
|
||||
raise
|
||||
|
||||
dsu_taskcluster.main()
|
||||
|
Loading…
Reference in New Issue
Block a user