diff --git a/training/coqui_stt_training/evaluate.py b/training/coqui_stt_training/evaluate.py index 965b3370..0647705f 100755 --- a/training/coqui_stt_training/evaluate.py +++ b/training/coqui_stt_training/evaluate.py @@ -14,7 +14,7 @@ import tensorflow.compat.v1 as tfv1 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer from six.moves import zip - +from .util.augmentations import NormalizeSampleRate from .util.config import Config, initialize_globals from .util.checkpoints import load_graph_for_evaluation from .util.evaluate_tools import calculate_and_print_report, save_samples_json @@ -53,6 +53,7 @@ def evaluate(test_csvs, create_model): test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False, + augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], reverse=FLAGS.reverse_test, limit=FLAGS.limit_test) for csv in test_csvs] iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index d8400982..def90001 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -28,6 +28,7 @@ from datetime import datetime from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from .evaluate import evaluate from six.moves import zip, range +from .util.augmentations import NormalizeSampleRate from .util.config import Config, initialize_globals from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint from .util.evaluate_tools import save_samples_json