From 904ab1e288ce2441dc8ad2e48adf7e0b6e8983ca Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 11 Apr 2019 16:26:11 -0300 Subject: [PATCH] Centralize progress logging and progress bar logic --- DeepSpeech.py | 13 ++----------- evaluate.py | 14 +++++++------- util/logging.py | 19 +++++++++++++++++++ 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 15f61c67..37cf222a 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -22,7 +22,7 @@ from tensorflow.python.tools import freeze_graph from util.config import Config, initialize_globals from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.flags import create_flags, FLAGS -from util.logging import log_info, log_error, log_debug +from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar # Graph Creation @@ -425,15 +425,6 @@ def train(): initializer = tf.global_variables_initializer() - # Disable progress logging if needed - if FLAGS.show_progressbar: - pbar_class = progressbar.ProgressBar - def log_progress(*args, **kwargs): - pass - else: - pbar_class = progressbar.NullBar - log_progress = log_info - with tf.Session(config=Config.session_config) as session: log_debug('Session opened.') @@ -479,7 +470,7 @@ def train(): ' | Steps: ', progressbar.widgets.Counter(), ' | ', LossWidget()] suffix = ' | Dataset: {}'.format(dataset) if dataset else None - pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start() + pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start() # Initialize iterator to the appropriate dataset session.run(init_op) diff --git a/evaluate.py b/evaluate.py index 203bcfea..2dc767f8 100755 --- a/evaluate.py +++ b/evaluate.py @@ -18,7 +18,7 @@ from util.config import Config, initialize_globals from util.evaluate_tools import calculate_report from util.feeding import create_dataset from util.flags import create_flags, FLAGS -from util.logging import log_error +from util.logging import log_error, log_progress, create_progressbar from util.text import levenshtein @@ -93,9 +93,9 @@ def evaluate(test_csvs, create_model, try_loading): seq_lengths = [] ground_truths = [] - bar = progressbar.ProgressBar(prefix='Computing acoustic model predictions | ', - widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()], - fd=sys.stdout).start() + bar = create_progressbar(prefix='Computing acoustic model predictions | ', + widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start() + log_progress('Computing acoustic model predictions...') step_count = 0 @@ -121,9 +121,9 @@ def evaluate(test_csvs, create_model, try_loading): predictions = [] - bar = progressbar.ProgressBar(max_value=step_count, - prefix='Decoding predictions | ', - fd=sys.stdout).start() + bar = create_progressbar(max_value=step_count, + prefix='Decoding predictions | ').start() + log_progress('Decoding predictions...') # Second pass, decode logits and compute WER and edit distance metrics for logits, seq_length in bar(zip(logitses, seq_lengths)): diff --git a/util/logging.py b/util/logging.py index b6f9ffb9..c7643a44 100644 --- a/util/logging.py +++ b/util/logging.py @@ -1,5 +1,8 @@ from __future__ import print_function +import progressbar +import sys + from util.flags import FLAGS @@ -28,3 +31,19 @@ def log_warn(message): def log_error(message): if FLAGS.log_level <= 3: prefix_print('E ', message) + + +def create_progressbar(*args, **kwargs): + # Progress bars in stdout by default + if 'fd' not in kwargs: + kwargs['fd'] = sys.stdout + + if FLAGS.show_progressbar: + return progressbar.ProgressBar(*args, **kwargs) + + return progressbar.NullBar(*args, **kwargs) + + +def log_progress(message): + if not FLAGS.show_progressbar: + log_info(message)