diff --git a/util/feeding.py b/util/feeding.py index 983fdb9c..2c33d2ae 100644 --- a/util/feeding.py +++ b/util/feeding.py @@ -30,7 +30,10 @@ def read_csvs(csv_files): return pandas.concat(sets, join='inner', ignore_index=True) -def samples_to_mfccs(samples, sample_rate, train_phase=False): +def samples_to_mfccs(samples, sample_rate, train_phase=False, wav_filename=None): + if train_phase and sample_rate != FLAGS.audio_sample_rate: + tf.print('WARNING: sample rate of file', wav_filename, '(', sample_rate, ') does not match FLAGS.audio_sample_rate. This can lead to incorrect results.') + spectrogram = contrib_audio.audio_spectrogram(samples, window_size=Config.audio_window_samples, stride=Config.audio_step_samples, @@ -79,7 +82,7 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False): def audiofile_to_features(wav_filename, train_phase=False): samples = tf.io.read_file(wav_filename) decoded = contrib_audio.decode_wav(samples, desired_channels=1) - features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase) + features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase, wav_filename=wav_filename) if train_phase: if FLAGS.data_aug_features_multiplicative > 0: