Improve training startup time
parent 26d10a5df3
commit b39da7f8b7
@@ -19,17 +19,15 @@ from util.flags import FLAGS
 from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up

 def read_csvs(csv_files):
-    source_data = None
+    sets = []
     for csv in csv_files:
         file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
         #FIXME: not cross-platform
         csv_dir = os.path.dirname(os.path.abspath(csv))
         file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop
-        if source_data is None:
-            source_data = file
-        else:
-            source_data = source_data.append(file, ignore_index=True)
-    return source_data
+        sets.append(file)
+    # Concat all sets, drop any extra columns, re-index the final result as 0..N
+    return pandas.concat(sets, join='inner', ignore_index=True)


 def samples_to_mfccs(samples, sample_rate, train_phase=False):
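Not part of the diff: a minimal sketch, assuming only pandas, of why this hunk speeds up startup. The removed loop re-copies the accumulated DataFrame on every append call (quadratic in the total number of rows), while collecting the per-CSV frames in a list and concatenating once copies each row a single time.

# Sketch only (hypothetical toy frames, not the project's CSV data):
import pandas

frames = [pandas.DataFrame({'wav_filename': ['clip%d.wav' % i],
                            'wav_filesize': [i],
                            'transcript': ['hello']})
          for i in range(1000)]

# The new pattern: a single concat at the end. join='inner' keeps only the
# columns common to every CSV, ignore_index=True re-indexes the result as 0..N.
combined = pandas.concat(frames, join='inner', ignore_index=True)
print(len(combined))  # 1000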
@@ -100,13 +98,7 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     df = read_csvs(csvs)
     df.sort_values(by='wav_filesize', inplace=True)

-    try:
-        # Convert to character index arrays
-        df = df.apply(partial(text_to_char_array, alphabet=Config.alphabet), result_type='broadcast', axis=1)
-    except ValueError as e:
-        error_message, series, *_ = e.args
-        log_error('While processing {}:\n {}'.format(series['wav_filename'], error_message))
-        exit(1)
+    df['transcript'] = df.apply(text_to_char_array, alphabet=Config.alphabet, result_type='reduce', axis=1)

     def generate_values():
         for _, row in df.iterrows():
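Not from the commit: a toy sketch of the apply call the new one-liner relies on, with a plain dict standing in for Config.alphabet. With axis=1 and result_type='reduce', pandas keeps each row's return value as a single element of the resulting Series instead of broadcasting it back over the frame's columns, which is what lets the encoded arrays be assigned straight to df['transcript'].

# Toy stand-ins for text_to_char_array / Config.alphabet (both hypothetical here).
import numpy as np
import pandas as pd

char_to_int = {c: i for i, c in enumerate('ab ')}

def encode_row(row):
    # Return only the encoded transcript for this row, as the reworked
    # text_to_char_array now does.
    return np.asarray([char_to_int[c] for c in row['transcript']])

df = pd.DataFrame({'wav_filename': ['a.wav', 'b.wav'],
                   'transcript': ['a b', 'ba']})
df['transcript'] = df.apply(encode_row, result_type='reduce', axis=1)
print(df['transcript'].iloc[0])  # [0 2 1]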
util/text.py (12 lines changed)
@@ -64,15 +64,13 @@ def text_to_char_array(series, alphabet):
     integers and return a numpy array representing the processed string.
     """
     try:
-        series['transcript'] = np.asarray(alphabet.encode(series['transcript']))
+        transcript = np.asarray(alphabet.encode(series['transcript']))
+        if not len(transcript):
+            raise ValueError('While processing: {}\nFound an empty transcript! You must include a transcript for all training data.'.format(series['wav_filename']))
+        return transcript
     except KeyError as e:
-        # Provide the row context (especially wav_filename) for alphabet errors
-        raise ValueError(str(e), series)
-
-    if series['transcript'].shape[0] == 0:
-        raise ValueError("Found an empty transcript! You must include a transcript for all training data.", series)
-
-    return series
+        raise ValueError('While processing: {}\n{}'.format(series['wav_filename'], e))


 # The following code is from: http://hetland.org/coding/python/levenshtein.py