Fix pylint warnings
parent: a16e468498
commit: 13757a4258
@@ -5,17 +5,17 @@ from __future__ import absolute_import, division, print_function
import os
import sys

- log_level_index = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_index > 0 and log_level_index < len(sys.argv) else '3'
+ LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'

import time
- import evaluate
import numpy as np
import progressbar
import shutil
import tensorflow as tf

from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
+ from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph
from util.config import Config, initialize_globals
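Two of the warnings silenced here are invalid-name (pylint expects module-level constants to be UPPER_CASE) and the chained-comparison suggestion for `x > 0 and x < n`. A minimal, self-contained sketch of the same pattern, using whatever happens to be in argv:

    import sys

    LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
    # 0 < i < n is the chained form of i > 0 and i < n, which is what pylint suggests
    if 0 < LOG_LEVEL_INDEX < len(sys.argv):
        log_level = sys.argv[LOG_LEVEL_INDEX]
    else:
        log_level = '3'
    print(log_level)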
@@ -49,7 +49,7 @@ def create_overlapping_windows(batch_x):
# convolution returns patches of the input tensor as is, and we can create
# overlapping windows over the MFCCs.
eye_filter = tf.constant(np.eye(window_width * num_channels)
- .reshape(window_width, num_channels, window_width * num_channels), tf.float32)
+ .reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation

# Create overlapping windows
batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
@@ -172,7 +172,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None,
# Conveniently, this loss function is implemented in TensorFlow.
# Thus, we can simply make use of this implementation to define our loss.

- def calculate_mean_edit_distance_and_loss(iterator, tower, dropout, reuse):
+ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
r'''
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
Next to total and average loss it returns the mean edit distance,
@@ -246,10 +246,10 @@ def get_tower_results(iterator, optimizer, dropout_rates):
device = Config.available_devices[i]
with tf.device(device):
# Create a scope for all operations of tower i
- with tf.name_scope('tower_%d' % i) as scope:
+ with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
- avg_loss = calculate_mean_edit_distance_and_loss(iterator, i, dropout_rates, reuse=i>0)
+ avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

# Allow for variables to be re-used by the next tower
tf.get_variable_scope().reuse_variables()
@@ -460,9 +460,9 @@ def train():
def __init__(self):
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')

- def __call__(self, progress, data):
+ def __call__(self, progress, data, **kwargs):
data['mean_loss'] = total_loss / step_count if step_count else 0.0
- return progressbar.widgets.FormatLabel.__call__(self, progress, data)
+ return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)

if FLAGS.show_progressbar:
pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
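Forwarding **kwargs keeps the override signature-compatible with the base widget, which is likely what pylint's arguments-differ check flagged. A hedged sketch with a simplified stand-in class instead of progressbar's real FormatLabel:

    class FormatLabel:
        def __call__(self, progress, data, **kwargs):
            return 'Loss: {mean_loss}'.format(**data)

    class LossWidget(FormatLabel):
        def __call__(self, progress, data, **kwargs):
            data['mean_loss'] = 0.0  # placeholder value for the sketch
            # passing **kwargs through means extra keyword arguments added upstream keep working
            return super().__call__(progress, data, **kwargs)

    print(LossWidget()(None, {}))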
@@ -547,7 +547,7 @@ def train():


def test():
- evaluate.evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+ evaluate(FLAGS.test_files.split(','), create_model, try_loading)


def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
@@ -570,12 +570,12 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# no state management since n_step is expected to be dynamic too (see below)
previous_state = previous_state_c = previous_state_h = None
else:
- if not tflite:
- previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
- previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
- else:
+ if tflite:
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
+ else:
+ previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
+ previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)

previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
@@ -620,28 +620,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
)

new_state_c, new_state_h = layers['rnn_output_state']
- if not tflite:
- zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
- initialize_c = tf.assign(previous_state_c, zero_state)
- initialize_h = tf.assign(previous_state_h, zero_state)
- initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
- with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
- logits = tf.identity(logits, name='logits')
-
- return (
- {
- 'input': input_tensor,
- 'input_lengths': seq_length,
- 'input_samples': input_samples,
- },
- {
- 'outputs': logits,
- 'initialize_state': initialize_state,
- 'mfccs': mfccs,
- },
- layers
- )
- else:
+ if tflite:
logits = tf.identity(logits, name='logits')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
@@ -656,17 +635,32 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
if FLAGS.use_seq_length:
inputs.update({'input_lengths': seq_length})

- return (
- inputs,
- {
- 'outputs': logits,
- 'new_state_c': new_state_c,
- 'new_state_h': new_state_h,
- 'mfccs': mfccs,
- },
- layers
- )
+ outputs = {
+ 'outputs': logits,
+ 'new_state_c': new_state_c,
+ 'new_state_h': new_state_h,
+ 'mfccs': mfccs,
+ }
+ else:
+ zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
+ initialize_c = tf.assign(previous_state_c, zero_state)
+ initialize_h = tf.assign(previous_state_h, zero_state)
+ initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
+ with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
+ logits = tf.identity(logits, name='logits')
+
+ inputs = {
+ 'input': input_tensor,
+ 'input_lengths': seq_length,
+ 'input_samples': input_samples,
+ }
+ outputs = {
+ 'outputs': logits,
+ 'initialize_state': initialize_state,
+ 'mfccs': mfccs,
+ }
+
+ return inputs, outputs, layers

def file_relative_read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
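The restructuring above swaps two early returns for dictionaries that are filled per branch and a single return at the end of the function. A toy sketch of that shape only (every value is a placeholder string, not the real tensors):

    def build_io(tflite=False):
        inputs = {'input': 'input_tensor'}
        if tflite:
            outputs = {'outputs': 'logits', 'new_state_c': 'c', 'new_state_h': 'h'}
        else:
            outputs = {'outputs': 'logits', 'initialize_state': 'init_op'}
        return inputs, outputs

    print(build_io(tflite=True))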
@@ -680,11 +674,9 @@ def export():
from tensorflow.python.framework.ops import Tensor, Operation

inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
input_names = ",".join(tensor.op.name for tensor in inputs.values())
- output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
- output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation)]
+ output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
+ output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)
input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())

if not FLAGS.export_tflite:
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
@@ -828,6 +820,6 @@ def main(_):
tf.reset_default_graph()
do_single_file_inference(FLAGS.one_shot_infer)

- if __name__ == '__main__' :
+ if __name__ == '__main__':
create_flags()
tf.app.run(main)
evaluate.py (28 changed lines)
@@ -4,13 +4,16 @@ from __future__ import absolute_import, division, print_function

import itertools
import json
+
+ from multiprocessing import cpu_count
+
import numpy as np
import progressbar
import tensorflow as tf

from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
- from multiprocessing import cpu_count
- from six.moves import zip, range
+ from six.moves import zip

from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
@@ -27,13 +30,12 @@ def sparse_tensor_value_to_texts(value, alphabet):
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)


- def sparse_tuple_to_texts(tuple, alphabet):
- indices = tuple[0]
- values = tuple[1]
- results = [''] * tuple[2][0]
- for i in range(len(indices)):
- index = indices[i][0]
- results[index] += alphabet.string_from_label(values[i])
+ def sparse_tuple_to_texts(sp_tuple, alphabet):
+ indices = sp_tuple[0]
+ values = sp_tuple[1]
+ results = [''] * sp_tuple[2][0]
+ for i, index in enumerate(indices):
+ results[index[0]] += alphabet.string_from_label(values[i])
# List of strings
return results

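Two pylint complaints are addressed in this hunk: the parameter named `tuple` shadowed the builtin (redefined-builtin), and the index loop is rewritten with `enumerate` (consider-using-enumerate). A small self-contained illustration on made-up sparse-tensor style data:

    indices = [(0, 0), (0, 1), (1, 0)]   # hypothetical (row, position) pairs
    values = ['a', 'b', 'c']
    results = [''] * 2
    for i, index in enumerate(indices):   # enumerate yields both the counter and the element
        results[index[0]] += values[i]
    print(results)   # ['ab', 'c']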
@@ -63,7 +65,7 @@ def evaluate(test_csvs, create_model, try_loading):
inputs=logits,
sequence_length=batch_x_len)

- global_step = tf.train.get_or_create_global_step()
+ tf.train.get_or_create_global_step()

with tf.Session(config=Config.session_config) as session:
# Create a saver using variables from the above newly created graph
@@ -109,7 +111,7 @@ def evaluate(test_csvs, create_model, try_loading):
# Get number of accessible CPU cores for this process
try:
num_processes = cpu_count()
- except:
+ except NotImplementedError:
num_processes = 1

print('Decoding predictions...')
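Catching NotImplementedError instead of everything answers pylint's bare-except warning while still covering the documented failure mode of cpu_count. The pattern in isolation:

    from multiprocessing import cpu_count

    try:
        num_processes = cpu_count()
    except NotImplementedError:   # cpu_count() raises this when the core count cannot be determined
        num_processes = 1
    print(num_processes)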
@@ -151,12 +153,12 @@ def main(_):
'the --test_files flag.')
exit(1)

- from DeepSpeech import create_model, try_loading
+ from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)

if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
- json.dump(samples, open(FLAGS.test_output_file, 'w'), default=lambda x: float(x))
+ json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


if __name__ == '__main__':
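`lambda x: float(x)` only forwards its argument to float, so the callable can be passed directly; that is pylint's unnecessary-lambda check. A self-contained example with a hypothetical decoded result:

    import json
    import numpy as np

    samples = [{'wer': np.float32(0.25)}]      # hypothetical result; np.float32 is not JSON serialisable by itself
    print(json.dumps(samples, default=float))  # float() is called for the NumPy value, printing [{"wer": 0.25}]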
@@ -1,55 +1,56 @@
- import csv
- import sys
- import glob
-
"""
Usage: $ python3 check_characters.py "INFILE"
e.g. $ python3 check_characters.py -csv /home/data/french.csv
- e.g. $ python3 check_characters.py -csv ../train.csv,../test.csv
- e.g. $ python3 check_characters.py -alpha -csv ../train.csv
+ e.g. $ python3 check_characters.py -csv ../train.csv,../test.csv
+ e.g. $ python3 check_characters.py -alpha -csv ../train.csv

- Point this script to your transcripts, and it returns
- to the terminal the unique set of characters in those
+ Point this script to your transcripts, and it returns
+ to the terminal the unique set of characters in those
files (combined).

These files are assumed to be csv, with the transcript being the third field.

- The script simply reads all the text from all the files,
- storing a set of unique characters that were seen
+ The script simply reads all the text from all the files,
+ storing a set of unique characters that were seen
along the way.
"""
+ import argparse
+ import csv
+ import os
+ import sys

- parser = argparse.ArgumentParser()
+ def main():
+ parser = argparse.ArgumentParser()

- parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
- parser.add_argument("-alpha", "--alphabet-format",help="Bool. Print in format for alphabet.txt",action="store_true")
- parser.set_defaults(alphabet_format=False)
- args = parser.parse_args()
- inFiles = [os.path.abspath(i) for i in args.csv_files.split(",")]
+ parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
+ parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
+ args = parser.parse_args()
+ in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]

- print("### Reading in the following transcript files: ###")
- print("### {} ###".format(inFiles))
+ print("### Reading in the following transcript files: ###")
+ print("### {} ###".format(in_files))

- allText = set()
- for inFile in (inFiles):
- with open(inFile, "r") as csvFile:
- reader = csv.reader(csvFile)
- try:
- next(reader, None) # skip the file header (i.e. "transcript")
- for row in reader:
- allText |= set(str(row[2]))
- except IndexError as ie:
- print("Your input file",inFile,"is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript")
- sys.exit(-1)
- finally:
- csvFile.close()
+ all_text = set()
+ for in_file in in_files:
+ with open(in_file, "r") as csv_file:
+ reader = csv.reader(csv_file)
+ try:
+ next(reader, None) # skip the file header (i.e. "transcript")
+ for row in reader:
+ all_text |= set(str(row[2]))
+ except IndexError:
+ print("Your input file", in_file, "is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript")
+ sys.exit(-1)
+ finally:
+ csv_file.close()

- print("### The following unique characters were found in your transcripts: ###")
- if args.alphabet_format:
- for char in list(allText):
- print(char)
- print("### ^^^ You can copy-paste these into data/alphabet.txt ###")
- else:
- print(list(allText))
+ print("### The following unique characters were found in your transcripts: ###")
+ if args.alphabet_format:
+ for char in list(all_text):
+ print(char)
+ print("### ^^^ You can copy-paste these into data/alphabet.txt ###")
+ else:
+ print(list(all_text))
+
+ if __name__ == '__main__':
+ main()
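Wrapping the script body in a main() function behind an `if __name__ == '__main__'` guard is what clears most of the invalid-name warnings here: names at module level are expected to be UPPER_CASE constants, while the same names inside a function are ordinary snake_case locals, and the module also becomes importable without side effects. The bare-bones shape:

    import argparse

    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("-csv", "--csv-files", required=True)
        args = parser.parse_args()
        print(args.csv_files)

    if __name__ == '__main__':
        main()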
@@ -4,11 +4,12 @@ import os
import tensorflow as tf

from attrdict import AttrDict
+ from xdg import BaseDirectory as xdg
+
from util.flags import FLAGS
from util.gpu import get_available_gpus
from util.logging import log_error
from util.text import Alphabet
- from xdg import BaseDirectory as xdg

class ConfigSingleton:
_config = None
@@ -21,7 +22,7 @@ class ConfigSingleton:
return ConfigSingleton._config[name]


- Config = ConfigSingleton()
+ Config = ConfigSingleton() # pylint: disable=invalid-name

def initialize_globals():
c = AttrDict()
@@ -33,7 +34,7 @@ def initialize_globals():
c.available_devices = get_available_gpus()

# If there is no GPU available, we fall back to CPU based operation
- if 0 == len(c.available_devices):
+ if not c.available_devices:
c.available_devices = [c.cpu_device]

# Set default dropout rates
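Testing the sequence itself rather than comparing its len() with zero is what pylint's len-as-condition check asks for, since an empty list is already falsy. Equivalent forms, with a hypothetical device list:

    available_devices = []                  # hypothetical: no GPU was found
    if not available_devices:               # preferred over `if 0 == len(available_devices):`
        available_devices = ['/cpu:0']
    print(available_devices)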
@@ -45,15 +46,15 @@ def initialize_globals():
FLAGS.dropout_rate6 = FLAGS.dropout_rate

# Set default checkpoint dir
- if len(FLAGS.checkpoint_dir) == 0:
- FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))
+ if not FLAGS.checkpoint_dir == 0:
+ FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

if FLAGS.load not in ['last', 'best', 'init', 'auto']:
FLAGS.load = 'auto'

# Set default summary dir
- if len(FLAGS.summary_dir) == 0:
- FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
+ if not FLAGS.summary_dir:
+ FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech', 'summaries'))

# Standard session configuration that'll be used for all new sessions.
c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
@@ -103,4 +104,4 @@ def initialize_globals():
log_error('Path specified in --one_shot_infer is not a valid file.')
exit(1)

- ConfigSingleton._config = c
+ ConfigSingleton._config = c # pylint: disable=protected-access
@@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

- from attrdict import AttrDict
from multiprocessing.dummy import Pool
+
+ from attrdict import AttrDict

from util.text import wer_cer_batch, levenshtein

def pmap(fun, iterable):
@@ -1,13 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

- import numpy as np
import os
+
+ from functools import partial
+
+ import numpy as np
import pandas
import tensorflow as tf

- from functools import partial
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

from util.config import Config
from util.text import text_to_char_array

@@ -18,7 +21,7 @@ def read_csvs(csv_files):
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
#FIXME: not cross-platform
csv_dir = os.path.dirname(os.path.abspath(csv))
- file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
+ file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop
if source_data is None:
source_data = file
else:
util/flags.py (126 changed lines)
@@ -10,110 +10,110 @@ def create_flags():
# Importer
# ========

- tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
- tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
- tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
- tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
+ f = tf.app.flags

- tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
- tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
- tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
+ f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
+ f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
+ f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')

- tf.app.flags.DEFINE_integer ('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
- tf.app.flags.DEFINE_integer ('feature_win_step', 20, 'feature extraction window step length in milliseconds')
- tf.app.flags.DEFINE_integer ('audio_sample_rate',16000, 'sample rate value expected by model')
+ f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
+ f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
+ f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')

+ f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
+ f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
+ f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')

# Global Constants
# ================

- tf.app.flags.DEFINE_integer ('epochs', 75, 'how many epochs (complete runs through the train files) to train for')
+ f.DEFINE_integer('epochs', 75, 'how many epochs (complete runs through the train files) to train for')

- tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
- tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
- tf.app.flags.DEFINE_float ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
- tf.app.flags.DEFINE_float ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
- tf.app.flags.DEFINE_float ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
- tf.app.flags.DEFINE_float ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')
+ f.DEFINE_float('dropout_rate', 0.05, 'dropout rate for feedforward layers')
+ f.DEFINE_float('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
+ f.DEFINE_float('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
+ f.DEFINE_float('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
+ f.DEFINE_float('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
+ f.DEFINE_float('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')

- tf.app.flags.DEFINE_float ('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')
+ f.DEFINE_float('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')

- # Adam optimizer (http://arxiv.org/abs/1412.6980) parameters
+ # Adam optimizer(http://arxiv.org/abs/1412.6980) parameters

- tf.app.flags.DEFINE_float ('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
- tf.app.flags.DEFINE_float ('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
- tf.app.flags.DEFINE_float ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
- tf.app.flags.DEFINE_float ('learning_rate', 0.001, 'learning rate of Adam optimizer')
+ f.DEFINE_float('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
+ f.DEFINE_float('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
+ f.DEFINE_float('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
+ f.DEFINE_float('learning_rate', 0.001, 'learning rate of Adam optimizer')

# Batch sizes

- tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch')
- tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch')
- tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch')
+ f.DEFINE_integer('train_batch_size', 1, 'number of elements in a training batch')
+ f.DEFINE_integer('dev_batch_size', 1, 'number of elements in a validation batch')
+ f.DEFINE_integer('test_batch_size', 1, 'number of elements in a test batch')

- tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph')
+ f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')

- # Performance (UNSUPPORTED)
- tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
- tf.app.flags.DEFINE_integer ('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
+ # Performance(UNSUPPORTED)
+ f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
+ f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')

# Sample limits

- tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
- tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
- tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
+ f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
+ f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
+ f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')

# Checkpointing

- tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
- tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
- tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
- tf.app.flags.DEFINE_string ('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')
+ f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+ f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
+ f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
+ f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')

# Exporting

- tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
- tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
- tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
- tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
- tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
- tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
- tf.app.flags.DEFINE_string ('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
+ f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
+ f.DEFINE_integer('export_version', 1, 'version number of the exported model')
+ f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
+ f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
+ f.DEFINE_boolean('use_seq_length', True, 'have sequence_length in the exported graph(will make tfcompile unhappy)')
+ f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
+ f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')

# Reporting

- tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
- tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
+ f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
+ f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')

- tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console')
- tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')
+ f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
+ f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report')

- tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
+ f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')

# Geometry

- tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers')
+ f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers')

# Initialization

- tf.app.flags.DEFINE_integer ('random_seed', 4568, 'default random seed that is used to initialize variables')
+ f.DEFINE_integer('random_seed', 4568, 'default random seed that is used to initialize variables')

# Early Stopping

- tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
- tf.app.flags.DEFINE_integer ('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
- tf.app.flags.DEFINE_float ('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
- tf.app.flags.DEFINE_float ('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
+ f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
+ f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
+ f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
+ f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

# Decoder

- tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
- tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
- tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
- tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
- tf.app.flags.DEFINE_float ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
- tf.app.flags.DEFINE_float ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
+ f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
+ f.DEFINE_string('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
+ f.DEFINE_string('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
+ f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
+ f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
+ f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')

# Inference mode

- tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
+ f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')

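Most of this hunk is mechanical: binding tf.app.flags to the short local name f shortens every definition, and the space before the opening parenthesis goes away; presumably that is what quiets the formatting warnings here. A minimal sketch of the aliasing, assuming the TF 1.x flags API:

    import tensorflow as tf

    def create_flags():
        f = tf.app.flags
        f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder')
        f.DEFINE_float('lm_alpha', 0.75, 'language model weight')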
@@ -27,4 +27,4 @@ def log_warn(message):

def log_error(message):
if FLAGS.log_level <= 3:
- prefix_print('E ', message)
+ prefix_print('E ', message)
@@ -2,12 +2,14 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division

+ import argparse
import platform
import subprocess
import sys
import os
import errno
import stat

import six.moves.urllib as urllib

from pkg_resources import parse_version
@@ -23,9 +25,9 @@ TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
assert arch_string is not None
assert artifact_name is not None
- assert len(artifact_name) > 0
+ assert artifact_name
assert branch_name is not None
- assert len(branch_name) > 0
+ assert branch_name

return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}

@@ -66,9 +68,7 @@ def maybe_download_tc_bin(**kwargs):
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()

- if __name__ == '__main__':
- import argparse
-
+ def main():
parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
parser.add_argument('--target', required=False,
help='Where to put the native client binary files')
@@ -151,3 +151,6 @@ if __name__ == '__main__':

if '.tar.' in args.artifact:
subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])
+
+ if __name__ == '__main__':
+ main()
util/text.py (25 changed lines)
@@ -1,9 +1,9 @@
from __future__ import absolute_import, division, print_function

import codecs
- import numpy as np
import re
import sys

+ import numpy as np
+
from six.moves import range

@@ -33,7 +33,6 @@ class Alphabet(object):
raise KeyError(
'''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
).with_traceback(e.__traceback__)
- sys.exit()

def decode(self, labels):
res = ''
|
||||
# version 1.0. This software is distributed without any warranty. For more
|
||||
# information, see <http://creativecommons.org/publicdomain/zero/1.0>
|
||||
|
||||
def levenshtein(a,b):
|
||||
def levenshtein(a, b):
|
||||
"Calculates the Levenshtein distance between a and b."
|
||||
n, m = len(a), len(b)
|
||||
if n > m:
|
||||
# Make sure n <= m, to use O(min(n,m)) space
|
||||
a,b = b,a
|
||||
n,m = m,n
|
||||
a, b = b, a
|
||||
n, m = m, n
|
||||
|
||||
current = list(range(n+1))
|
||||
for i in range(1,m+1):
|
||||
for i in range(1, m+1):
|
||||
previous, current = current, [i]+[0]*n
|
||||
for j in range(1,n+1):
|
||||
for j in range(1, n+1):
|
||||
add, delete = previous[j]+1, current[j-1]+1
|
||||
change = previous[j-1]
|
||||
if a[j-1] != b[i-1]:
|
||||
@ -118,14 +117,7 @@ def levenshtein(a,b):
|
||||
# or None if it's invalid.
|
||||
def validate_label(label):
|
||||
# For now we can only handle [a-z ']
|
||||
if "(" in label or \
|
||||
"<" in label or \
|
||||
"[" in label or \
|
||||
"]" in label or \
|
||||
"&" in label or \
|
||||
"*" in label or \
|
||||
"{" in label or \
|
||||
re.search(r"[0-9]", label) != None:
|
||||
if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
|
||||
return None
|
||||
|
||||
label = label.replace("-", "")
|
||||
@ -138,4 +130,3 @@ def validate_label(label):
|
||||
label = label.lower()
|
||||
|
||||
return label if label else None
|
||||
|
||||
|