From 11782154232d68dea0c71d0923b80b586a2ad2fa Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 18 Feb 2020 18:13:36 +0100 Subject: [PATCH] Warn if --audio_sample_rate does not match training sample In PR #2688, we started specifying the upper frequency limit when computing Mfccs. This value was computed as half of the --audio_sample_rate value. Despite accepting a variable sample rate input for the Mfcc computation, the TensorFlow OP only takes a constant upper frequency limit, so we can't pass a dynamic value computed from each sample to the op. This means we lost the ability to transparently train on data with multiple sample rates. This commit adds a warning message in case a training sample does not match the --audio_sample_rate flag. --- util/feeding.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/util/feeding.py b/util/feeding.py index 983fdb9c..2c33d2ae 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -30,7 +30,10 @@ def read_csvs(csv_files): return pandas.concat(sets, join='inner', ignore_index=True) -def samples_to_mfccs(samples, sample_rate, train_phase=False): +def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None): + if train_phase and sample_rate != FLAGS.audio_sample_rate: + tf.print('WARNING: sample rate of file', wav_filename, '(', sample_rate, ') does not match FLAGS.audio_sample_rate. This can lead to incorrect results.') + spectrogram = contrib_audio.audio_spectrogram(samples, window_size=Config.audio_window_samples, stride=Config.audio_step_samples, @@ -79,7 +82,7 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False): def audiofile_to_features(wav_filename, train_phase=False): samples = tf.io.read_file(wav_filename) decoded = contrib_audio.decode_wav(samples, desired_channels=1) - features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase) + features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase, wav_filename=wav_filename) if train_phase: if FLAGS.data_aug_features_multiplicative > 0: