Centralize progress logging and progress bar logic

This commit is contained in:
Reuben Morais 2019-04-11 16:26:11 -03:00
parent 9586fbbd30
commit 904ab1e288
3 changed files with 28 additions and 18 deletions

View File

@ -22,7 +22,7 @@ from tensorflow.python.tools import freeze_graph
from util.config import Config, initialize_globals
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.logging import log_info, log_error, log_debug
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
# Graph Creation
@ -425,15 +425,6 @@ def train():
initializer = tf.global_variables_initializer()
# Disable progress logging if needed
if FLAGS.show_progressbar:
pbar_class = progressbar.ProgressBar
def log_progress(*args, **kwargs):
pass
else:
pbar_class = progressbar.NullBar
log_progress = log_info
with tf.Session(config=Config.session_config) as session:
log_debug('Session opened.')
@ -479,7 +470,7 @@ def train():
' | Steps: ', progressbar.widgets.Counter(),
' | ', LossWidget()]
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start()
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
# Initialize iterator to the appropriate dataset
session.run(init_op)

View File

@ -18,7 +18,7 @@ from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.logging import log_error, log_progress, create_progressbar
from util.text import levenshtein
@ -93,9 +93,9 @@ def evaluate(test_csvs, create_model, try_loading):
seq_lengths = []
ground_truths = []
bar = progressbar.ProgressBar(prefix='Computing acoustic model predictions | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()],
fd=sys.stdout).start()
bar = create_progressbar(prefix='Computing acoustic model predictions | ',
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
log_progress('Computing acoustic model predictions...')
step_count = 0
@ -121,9 +121,9 @@ def evaluate(test_csvs, create_model, try_loading):
predictions = []
bar = progressbar.ProgressBar(max_value=step_count,
prefix='Decoding predictions | ',
fd=sys.stdout).start()
bar = create_progressbar(max_value=step_count,
prefix='Decoding predictions | ').start()
log_progress('Decoding predictions...')
# Second pass, decode logits and compute WER and edit distance metrics
for logits, seq_length in bar(zip(logitses, seq_lengths)):

View File

@ -1,5 +1,8 @@
from __future__ import print_function
import progressbar
import sys
from util.flags import FLAGS
@ -28,3 +31,19 @@ def log_warn(message):
def log_error(message):
if FLAGS.log_level <= 3:
prefix_print('E ', message)
def create_progressbar(*args, **kwargs):
# Progress bars in stdout by default
if 'fd' not in kwargs:
kwargs['fd'] = sys.stdout
if FLAGS.show_progressbar:
return progressbar.ProgressBar(*args, **kwargs)
return progressbar.NullBar(*args, **kwargs)
def log_progress(message):
if not FLAGS.show_progressbar:
log_info(message)