Centralize progress logging and progress bar logic
This commit is contained in:
parent
9586fbbd30
commit
904ab1e288
@ -22,7 +22,7 @@ from tensorflow.python.tools import freeze_graph
|
||||
from util.config import Config, initialize_globals
|
||||
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_info, log_error, log_debug
|
||||
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
|
||||
|
||||
|
||||
# Graph Creation
|
||||
@ -425,15 +425,6 @@ def train():
|
||||
|
||||
initializer = tf.global_variables_initializer()
|
||||
|
||||
# Disable progress logging if needed
|
||||
if FLAGS.show_progressbar:
|
||||
pbar_class = progressbar.ProgressBar
|
||||
def log_progress(*args, **kwargs):
|
||||
pass
|
||||
else:
|
||||
pbar_class = progressbar.NullBar
|
||||
log_progress = log_info
|
||||
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
log_debug('Session opened.')
|
||||
|
||||
@ -479,7 +470,7 @@ def train():
|
||||
' | Steps: ', progressbar.widgets.Counter(),
|
||||
' | ', LossWidget()]
|
||||
suffix = ' | Dataset: {}'.format(dataset) if dataset else None
|
||||
pbar = pbar_class(prefix=prefix, widgets=widgets, suffix=suffix, fd=sys.stdout).start()
|
||||
pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
|
||||
|
||||
# Initialize iterator to the appropriate dataset
|
||||
session.run(init_op)
|
||||
|
14
evaluate.py
14
evaluate.py
@ -18,7 +18,7 @@ from util.config import Config, initialize_globals
|
||||
from util.evaluate_tools import calculate_report
|
||||
from util.feeding import create_dataset
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error
|
||||
from util.logging import log_error, log_progress, create_progressbar
|
||||
from util.text import levenshtein
|
||||
|
||||
|
||||
@ -93,9 +93,9 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
seq_lengths = []
|
||||
ground_truths = []
|
||||
|
||||
bar = progressbar.ProgressBar(prefix='Computing acoustic model predictions | ',
|
||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()],
|
||||
fd=sys.stdout).start()
|
||||
bar = create_progressbar(prefix='Computing acoustic model predictions | ',
|
||||
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start()
|
||||
log_progress('Computing acoustic model predictions...')
|
||||
|
||||
step_count = 0
|
||||
|
||||
@ -121,9 +121,9 @@ def evaluate(test_csvs, create_model, try_loading):
|
||||
|
||||
predictions = []
|
||||
|
||||
bar = progressbar.ProgressBar(max_value=step_count,
|
||||
prefix='Decoding predictions | ',
|
||||
fd=sys.stdout).start()
|
||||
bar = create_progressbar(max_value=step_count,
|
||||
prefix='Decoding predictions | ').start()
|
||||
log_progress('Decoding predictions...')
|
||||
|
||||
# Second pass, decode logits and compute WER and edit distance metrics
|
||||
for logits, seq_length in bar(zip(logitses, seq_lengths)):
|
||||
|
@ -1,5 +1,8 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import progressbar
|
||||
import sys
|
||||
|
||||
from util.flags import FLAGS
|
||||
|
||||
|
||||
@ -28,3 +31,19 @@ def log_warn(message):
|
||||
def log_error(message):
|
||||
if FLAGS.log_level <= 3:
|
||||
prefix_print('E ', message)
|
||||
|
||||
|
||||
def create_progressbar(*args, **kwargs):
|
||||
# Progress bars in stdout by default
|
||||
if 'fd' not in kwargs:
|
||||
kwargs['fd'] = sys.stdout
|
||||
|
||||
if FLAGS.show_progressbar:
|
||||
return progressbar.ProgressBar(*args, **kwargs)
|
||||
|
||||
return progressbar.NullBar(*args, **kwargs)
|
||||
|
||||
|
||||
def log_progress(message):
|
||||
if not FLAGS.show_progressbar:
|
||||
log_info(message)
|
||||
|
Loading…
x
Reference in New Issue
Block a user