Improve training startup time

This commit is contained in:
Reuben Morais 2019-10-29 11:28:47 +01:00
parent 26d10a5df3
commit b39da7f8b7
2 changed files with 10 additions and 20 deletions

View File

@ -19,17 +19,15 @@ from util.flags import FLAGS
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
def read_csvs(csv_files):
source_data = None
sets = []
for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
#FIXME: not cross-platform
csv_dir = os.path.dirname(os.path.abspath(csv))
file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop
if source_data is None:
source_data = file
else:
source_data = source_data.append(file, ignore_index=True)
return source_data
sets.append(file)
# Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True)
def samples_to_mfccs(samples, sample_rate, train_phase=False):
@ -100,13 +98,7 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
df = read_csvs(csvs)
df.sort_values(by='wav_filesize', inplace=True)
try:
# Convert to character index arrays
df = df.apply(partial(text_to_char_array, alphabet=Config.alphabet), result_type='broadcast', axis=1)
except ValueError as e:
error_message, series, *_ = e.args
log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
exit(1)
df['transcript'] = df.apply(text_to_char_array, alphabet=Config.alphabet, result_type='reduce', axis=1)
def generate_values():
for _, row in df.iterrows():

View File

@ -64,15 +64,13 @@ def text_to_char_array(series, alphabet):
integers and return a numpy array representing the processed string.
"""
try:
series['transcript'] = np.asarray(alphabet.encode(series['transcript']))
transcript = np.asarray(alphabet.encode(series['transcript']))
if not len(transcript):
raise ValueError('While processing: {}\nFound an empty transcript! You must include a transcript for all training data.'.format(series['wav_filename']))
return transcript
except KeyError as e:
# Provide the row context (especially wav_filename) for alphabet errors
raise ValueError(str(e), series)
if series['transcript'].shape[0] == 0:
raise ValueError("Found an empty transcript! You must include a transcript for all training data.", series)
return series
raise ValueError('While processing: {}\n{}'.format(series['wav_filename'], e))
# The following code is from: http://hetland.org/coding/python/levenshtein.py