Fix pylint warnings

Reuben Morais 2019-04-11 07:02:21 -03:00
parent a16e468498
commit 13757a4258
10 changed files with 195 additions and 200 deletions

View File

@ -5,17 +5,17 @@ from __future__ import absolute_import, division, print_function
import os
import sys
log_level_index = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_index > 0 and log_level_index < len(sys.argv) else '3'
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
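For intuition (not part of the commit): this block has to run before importing tensorflow so that TF_CPP_MIN_LOG_LEVEL is already set when the library initialises, and the chained comparison is the form pylint prefers over two separate bounds checks. A throwaway sketch of the argv scan, with a made-up command line:
# Hypothetical invocation: python DeepSpeech.py --log_level 2 --train_files train.csv
argv = ['DeepSpeech.py', '--log_level', '2', '--train_files', 'train.csv']
idx = argv.index('--log_level') + 1 if '--log_level' in argv else 0
assert idx == 2                                        # points at the value '2'
assert (argv[idx] if 0 < idx < len(argv) else '3') == '2'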
import time
import evaluate
import numpy as np
import progressbar
import shutil
import tensorflow as tf
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph
from util.config import Config, initialize_globals
@ -49,7 +49,7 @@ def create_overlapping_windows(batch_x):
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32)
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
# Create overlapping windows
batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
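A toy-sized sketch (not from this commit) of the identity-filter trick above: reshaping np.eye(window_width * num_channels) into a [filter_width, in_channels, out_channels] kernel makes tf.nn.conv1d return, at every time step, a flattened copy of the window_width frames around it. Dimensions below are made up; the repository uses the TF 1.x API.
# Illustrative only; tiny dimensions so the shapes are easy to follow.
import numpy as np
import tensorflow as tf  # TF 1.x, as used by this repository

window_width, num_channels = 3, 2
eye_filter = tf.constant(np.eye(window_width * num_channels)
                         .reshape(window_width, num_channels, window_width * num_channels),
                         tf.float32)
batch_x = tf.constant(np.arange(10, dtype=np.float32).reshape(1, 5, num_channels))  # [batch, time, channels]
windows = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
with tf.Session() as sess:
    print(sess.run(windows).shape)  # (1, 5, 6): five windows of 3 frames x 2 channels each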
@ -172,7 +172,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None,
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.
def calculate_mean_edit_distance_and_loss(iterator, tower, dropout, reuse):
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
Next to total and average loss it returns the mean edit distance,
@ -246,10 +246,10 @@ def get_tower_results(iterator, optimizer, dropout_rates):
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
with tf.name_scope('tower_%d' % i) as scope:
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss = calculate_mean_edit_distance_and_loss(iterator, i, dropout_rates, reuse=i>0)
avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
# Allow for variables to be re-used by the next tower
tf.get_variable_scope().reuse_variables()
@ -460,9 +460,9 @@ def train():
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
def __call__(self, progress, data):
def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
return progressbar.widgets.FormatLabel.__call__(self, progress, data)
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
if FLAGS.show_progressbar:
pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
@ -547,7 +547,7 @@ def train():
def test():
evaluate.evaluate(FLAGS.test_files.split(','), create_model, try_loading)
evaluate(FLAGS.test_files.split(','), create_model, try_loading)
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
@ -570,12 +570,12 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# no state management since n_step is expected to be dynamic too (see below)
previous_state = previous_state_c = previous_state_h = None
else:
if not tflite:
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
else:
if tflite:
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
else:
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
@ -620,28 +620,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
)
new_state_c, new_state_h = layers['rnn_output_state']
if not tflite:
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
initialize_c = tf.assign(previous_state_c, zero_state)
initialize_h = tf.assign(previous_state_h, zero_state)
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
logits = tf.identity(logits, name='logits')
return (
{
'input': input_tensor,
'input_lengths': seq_length,
'input_samples': input_samples,
},
{
'outputs': logits,
'initialize_state': initialize_state,
'mfccs': mfccs,
},
layers
)
else:
if tflite:
logits = tf.identity(logits, name='logits')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
@ -656,17 +635,32 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
if FLAGS.use_seq_length:
inputs.update({'input_lengths': seq_length})
return (
inputs,
{
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
},
layers
)
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
else:
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
initialize_c = tf.assign(previous_state_c, zero_state)
initialize_h = tf.assign(previous_state_h, zero_state)
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
logits = tf.identity(logits, name='logits')
inputs = {
'input': input_tensor,
'input_lengths': seq_length,
'input_samples': input_samples,
}
outputs = {
'outputs': logits,
'initialize_state': initialize_state,
'mfccs': mfccs,
}
return inputs, outputs, layers
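A rough driver sketch, assumption only (variables would still need to be restored from a checkpoint before the logits mean anything, and feature feeding is omitted): the non-TFLite graph is reset once via the initialize_state op, after which every fetch of the logits advances the LSTM state through the tf.assign ops above; the TFLite variant instead expects the caller to feed previous_state_c/h and read new_state_c/h back on every call.
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=16, tflite=False)
with tf.Session(config=Config.session_config) as session:
    session.run(outputs['initialize_state'])   # zero the in-graph LSTM state once
    # ...then repeatedly feed inputs['input'] (and inputs['input_lengths']) chunk
    # by chunk and fetch outputs['outputs']; state carries over between runs.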
def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
@ -680,11 +674,9 @@ def export():
from tensorflow.python.framework.ops import Tensor, Operation
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
input_names = ",".join(tensor.op.name for tensor in inputs.values())
output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation)]
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)
input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
if not FLAGS.export_tflite:
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
@ -828,6 +820,6 @@ def main(_):
tf.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)
if __name__ == '__main__' :
if __name__ == '__main__':
create_flags()
tf.app.run(main)

View File

@ -4,13 +4,16 @@ from __future__ import absolute_import, division, print_function
import itertools
import json
from multiprocessing import cpu_count
import numpy as np
import progressbar
import tensorflow as tf
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from multiprocessing import cpu_count
from six.moves import zip, range
from six.moves import zip
from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
@ -27,13 +30,12 @@ def sparse_tensor_value_to_texts(value, alphabet):
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
def sparse_tuple_to_texts(tuple, alphabet):
indices = tuple[0]
values = tuple[1]
results = [''] * tuple[2][0]
for i in range(len(indices)):
index = indices[i][0]
results[index] += alphabet.string_from_label(values[i])
def sparse_tuple_to_texts(sp_tuple, alphabet):
indices = sp_tuple[0]
values = sp_tuple[1]
results = [''] * sp_tuple[2][0]
for i, index in enumerate(indices):
results[index[0]] += alphabet.string_from_label(values[i])
# List of strings
return results
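A tiny illustration (not in the commit) of the sparse-tuple layout this helper expects: indices are (batch item, timestep) pairs, values are label ids, and dense_shape[0] is the batch size. The alphabet stub below is hypothetical but mirrors the string_from_label interface used above.
class _ToyAlphabet:
    def string_from_label(self, label):
        return 'abc'[label]

sp = ([[0, 0], [0, 1], [1, 0]],  # indices: (batch item, timestep)
      [0, 1, 2],                 # values: label ids
      [2, 2])                    # dense_shape: [batch size, max length]
print(sparse_tuple_to_texts(sp, _ToyAlphabet()))  # ['ab', 'c']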
@ -63,7 +65,7 @@ def evaluate(test_csvs, create_model, try_loading):
inputs=logits,
sequence_length=batch_x_len)
global_step = tf.train.get_or_create_global_step()
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
# Create a saver using variables from the above newly created graph
@ -109,7 +111,7 @@ def evaluate(test_csvs, create_model, try_loading):
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
except:
except NotImplementedError:
num_processes = 1
print('Decoding predictions...')
@ -151,12 +153,12 @@ def main(_):
'the --test_files flag.')
exit(1)
from DeepSpeech import create_model, try_loading
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=lambda x: float(x))
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
if __name__ == '__main__':
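Why default=float is enough here (a standalone sketch, not from the commit): json.dump calls the default callable for any object it cannot serialise natively, and NumPy scalars such as np.float32 convert cleanly through float().
import json
import numpy as np

print(json.dumps({'wer': np.float32(0.25)}, default=float))  # {"wer": 0.25}
# Without default=float the np.float32 value raises a TypeError.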

View File

@ -1,55 +1,56 @@
import csv
import sys
import glob
"""
Usage: $ python3 check_characters.py "INFILE"
e.g. $ python3 check_characters.py -csv /home/data/french.csv
e.g. $ python3 check_characters.py -csv ../train.csv,../test.csv
e.g. $ python3 check_characters.py -alpha -csv ../train.csv
e.g. $ python3 check_characters.py -csv ../train.csv,../test.csv
e.g. $ python3 check_characters.py -alpha -csv ../train.csv
Point this script to your transcripts, and it returns
to the terminal the unique set of characters in those
Point this script to your transcripts, and it returns
to the terminal the unique set of characters in those
files (combined).
These files are assumed to be csv, with the transcript being the third field.
The script simply reads all the text from all the files,
storing a set of unique characters that were seen
The script simply reads all the text from all the files,
storing a set of unique characters that were seen
along the way.
"""
import argparse
import csv
import os
import sys
parser = argparse.ArgumentParser()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
parser.add_argument("-alpha", "--alphabet-format",help="Bool. Print in format for alphabet.txt",action="store_true")
parser.set_defaults(alphabet_format=False)
args = parser.parse_args()
inFiles = [os.path.abspath(i) for i in args.csv_files.split(",")]
parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
args = parser.parse_args()
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
print("### Reading in the following transcript files: ###")
print("### {} ###".format(inFiles))
print("### Reading in the following transcript files: ###")
print("### {} ###".format(in_files))
allText = set()
for inFile in (inFiles):
with open(inFile, "r") as csvFile:
reader = csv.reader(csvFile)
try:
next(reader, None) # skip the file header (i.e. "transcript")
for row in reader:
allText |= set(str(row[2]))
except IndexError as ie:
print("Your input file",inFile,"is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript")
sys.exit(-1)
finally:
csvFile.close()
all_text = set()
for in_file in in_files:
with open(in_file, "r") as csv_file:
reader = csv.reader(csv_file)
try:
next(reader, None) # skip the file header (i.e. "transcript")
for row in reader:
all_text |= set(str(row[2]))
except IndexError:
print("Your input file", in_file, "is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript")
sys.exit(-1)
finally:
csv_file.close()
print("### The following unique characters were found in your transcripts: ###")
if args.alphabet_format:
for char in list(allText):
print(char)
print("### ^^^ You can copy-paste these into data/alphabet.txt ###")
else:
print(list(allText))
print("### The following unique characters were found in your transcripts: ###")
if args.alphabet_format:
for char in list(all_text):
print(char)
print("### ^^^ You can copy-paste these into data/alphabet.txt ###")
else:
print(list(all_text))
if __name__ == '__main__':
main()

View File

@ -4,11 +4,12 @@ import os
import tensorflow as tf
from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from util.flags import FLAGS
from util.gpu import get_available_gpus
from util.logging import log_error
from util.text import Alphabet
from xdg import BaseDirectory as xdg
class ConfigSingleton:
_config = None
@ -21,7 +22,7 @@ class ConfigSingleton:
return ConfigSingleton._config[name]
Config = ConfigSingleton()
Config = ConfigSingleton() # pylint: disable=invalid-name
def initialize_globals():
c = AttrDict()
@ -33,7 +34,7 @@ def initialize_globals():
c.available_devices = get_available_gpus()
# If there is no GPU available, we fall back to CPU based operation
if 0 == len(c.available_devices):
if not c.available_devices:
c.available_devices = [c.cpu_device]
# Set default dropout rates
@ -45,15 +46,15 @@ def initialize_globals():
FLAGS.dropout_rate6 = FLAGS.dropout_rate
# Set default checkpoint dir
if len(FLAGS.checkpoint_dir) == 0:
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))
if not FLAGS.checkpoint_dir:
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))
if FLAGS.load not in ['last', 'best', 'init', 'auto']:
FLAGS.load = 'auto'
# Set default summary dir
if len(FLAGS.summary_dir) == 0:
FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
if not FLAGS.summary_dir:
FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech', 'summaries'))
# Standard session configuration that'll be used for all new sessions.
c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
@ -103,4 +104,4 @@ def initialize_globals():
log_error('Path specified in --one_shot_infer is not a valid file.')
exit(1)
ConfigSingleton._config = c
ConfigSingleton._config = c # pylint: disable=protected-access

View File

@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
from attrdict import AttrDict
from multiprocessing.dummy import Pool
from attrdict import AttrDict
from util.text import wer_cer_batch, levenshtein
def pmap(fun, iterable):

View File

@ -1,13 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import numpy as np
import os
from functools import partial
import numpy as np
import pandas
import tensorflow as tf
from functools import partial
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from util.config import Config
from util.text import text_to_char_array
@ -18,7 +21,7 @@ def read_csvs(csv_files):
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
#FIXME: not cross-platform
csv_dir = os.path.dirname(os.path.abspath(csv))
file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop
if source_data is None:
source_data = file
else:

View File

@ -10,110 +10,110 @@ def create_flags():
# Importer
# ========
tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
f = tf.app.flags
tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
tf.app.flags.DEFINE_integer ('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
tf.app.flags.DEFINE_integer ('feature_win_step', 20, 'feature extraction window step length in milliseconds')
tf.app.flags.DEFINE_integer ('audio_sample_rate',16000, 'sample rate value expected by model')
f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')
# Global Constants
# ================
tf.app.flags.DEFINE_integer ('epochs', 75, 'how many epochs (complete runs through the train files) to train for')
f.DEFINE_integer('epochs', 75, 'how many epochs (complete runs through the train files) to train for')
tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
tf.app.flags.DEFINE_float ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
tf.app.flags.DEFINE_float ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
tf.app.flags.DEFINE_float ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
tf.app.flags.DEFINE_float ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')
f.DEFINE_float('dropout_rate', 0.05, 'dropout rate for feedforward layers')
f.DEFINE_float('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
f.DEFINE_float('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
f.DEFINE_float('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
f.DEFINE_float('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
f.DEFINE_float('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')
tf.app.flags.DEFINE_float ('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')
f.DEFINE_float('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')
# Adam optimizer (http://arxiv.org/abs/1412.6980) parameters
# Adam optimizer(http://arxiv.org/abs/1412.6980) parameters
tf.app.flags.DEFINE_float ('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
tf.app.flags.DEFINE_float ('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
tf.app.flags.DEFINE_float ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
tf.app.flags.DEFINE_float ('learning_rate', 0.001, 'learning rate of Adam optimizer')
f.DEFINE_float('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
f.DEFINE_float('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
f.DEFINE_float('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
f.DEFINE_float('learning_rate', 0.001, 'learning rate of Adam optimizer')
# Batch sizes
tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch')
tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch')
tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch')
f.DEFINE_integer('train_batch_size', 1, 'number of elements in a training batch')
f.DEFINE_integer('dev_batch_size', 1, 'number of elements in a validation batch')
f.DEFINE_integer('test_batch_size', 1, 'number of elements in a test batch')
tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph')
f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')
# Performance (UNSUPPORTED)
tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
tf.app.flags.DEFINE_integer ('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
# Performance(UNSUPPORTED)
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
# Sample limits
tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
# Checkpointing
tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
tf.app.flags.DEFINE_string ('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')
f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')
# Exporting
tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
tf.app.flags.DEFINE_string ('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
f.DEFINE_integer('export_version', 1, 'version number of the exported model')
f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
f.DEFINE_boolean('use_seq_length', True, 'have sequence_length in the exported graph(will make tfcompile unhappy)')
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
# Reporting
tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console')
tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')
f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report')
tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
# Geometry
tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers')
f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers')
# Initialization
tf.app.flags.DEFINE_integer ('random_seed', 4568, 'default random seed that is used to initialize variables')
f.DEFINE_integer('random_seed', 4568, 'default random seed that is used to initialize variables')
# Early Stopping
tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
tf.app.flags.DEFINE_integer ('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
tf.app.flags.DEFINE_float ('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
tf.app.flags.DEFINE_float ('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
# Decoder
tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
f.DEFINE_string('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
f.DEFINE_string('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
# Inference mode
tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
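For reference, a self-contained toy (not part of the project) showing the f = tf.app.flags shortcut used above and how a defined flag is read back once tf.app.run has parsed the command line:
import tensorflow as tf  # TF 1.x

f = tf.app.flags
f.DEFINE_integer('beam_width', 1024, 'beam width used by the CTC decoder')
FLAGS = f.FLAGS

def main(_):
    print(FLAGS.beam_width)  # 1024, or whatever --beam_width was passed

if __name__ == '__main__':
    tf.app.run(main)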

View File

@ -27,4 +27,4 @@ def log_warn(message):
def log_error(message):
if FLAGS.log_level <= 3:
prefix_print('E ', message)
prefix_print('E ', message)

View File

@ -2,12 +2,14 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division
import argparse
import platform
import subprocess
import sys
import os
import errno
import stat
import six.moves.urllib as urllib
from pkg_resources import parse_version
@ -23,9 +25,9 @@ TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
assert len(artifact_name) > 0
assert artifact_name
assert branch_name is not None
assert len(branch_name) > 0
assert branch_name
return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
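A toy illustration (the URL scheme below is made up, not the project's real TASKCLUSTER_SCHEME) of the %-formatting with named placeholders this function relies on:
scheme = 'https://example.invalid/%(branch_name)s/%(arch_string)s/%(artifact_name)s'
print(scheme % {'arch_string': 'gpu',
                'artifact_name': 'native_client.tar.xz',
                'branch_name': 'master'})
# -> https://example.invalid/master/gpu/native_client.tar.xz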
@ -66,9 +68,7 @@ def maybe_download_tc_bin(**kwargs):
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
if __name__ == '__main__':
import argparse
def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
@ -151,3 +151,6 @@ if __name__ == '__main__':
if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
if __name__ == '__main__':
main()

View File

@ -1,9 +1,9 @@
from __future__ import absolute_import, division, print_function
import codecs
import numpy as np
import re
import sys
import numpy as np
from six.moves import range
@ -33,7 +33,6 @@ class Alphabet(object):
raise KeyError(
'''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
).with_traceback(e.__traceback__)
sys.exit()
def decode(self, labels):
res = ''
@ -94,18 +93,18 @@ def wer_cer_batch(originals, results):
# version 1.0. This software is distributed without any warranty. For more
# information, see <http://creativecommons.org/publicdomain/zero/1.0>
def levenshtein(a,b):
def levenshtein(a, b):
"Calculates the Levenshtein distance between a and b."
n, m = len(a), len(b)
if n > m:
# Make sure n <= m, to use O(min(n,m)) space
a,b = b,a
n,m = m,n
a, b = b, a
n, m = m, n
current = list(range(n+1))
for i in range(1,m+1):
for i in range(1, m+1):
previous, current = current, [i]+[0]*n
for j in range(1,n+1):
for j in range(1, n+1):
add, delete = previous[j]+1, current[j-1]+1
change = previous[j-1]
if a[j-1] != b[i-1]:
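Quick sanity checks (illustrative, not part of the commit) for the distance computed above:
assert levenshtein('kitten', 'sitting') == 3   # two substitutions plus one insertion
assert levenshtein('', 'abc') == 3
assert levenshtein('same', 'same') == 0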
@ -118,14 +117,7 @@ def levenshtein(a,b):
# or None if it's invalid.
def validate_label(label):
# For now we can only handle [a-z ']
if "(" in label or \
"<" in label or \
"[" in label or \
"]" in label or \
"&" in label or \
"*" in label or \
"{" in label or \
re.search(r"[0-9]", label) != None:
if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
return None
label = label.replace("-", "")
@ -138,4 +130,3 @@ def validate_label(label):
label = label.lower()
return label if label else None
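Illustrative checks (not in the commit) that the single regex above rejects the same labels the old chain of substring tests did:
assert validate_label('call me at 5') is None        # contains a digit
assert validate_label('noise [laughter]') is None    # contains '[' / ']'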