diff --git a/DeepSpeech.py b/DeepSpeech.py
old mode 100755
new mode 100644
index 7b176520..42c1d782
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -394,7 +394,7 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)
 
 
-def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
+def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
         if not checkpoint:
@@ -403,8 +403,10 @@ def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
         saver.restore(session, checkpoint_path)
         if load_step:
             restored_step = session.run(tfv1.train.get_global_step())
-            log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
-        else:
+            if log_success:
+                log_info('Restored variables from %s checkpoint at %s, step %d' %
+                         (caption, checkpoint_path, restored_step))
+        elif log_success:
             log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
     except tf.errors.InvalidArgumentError as e:
diff --git a/transcribe.py b/transcribe.py
index c07ae249..8c761a9a 100755
--- a/transcribe.py
+++ b/transcribe.py
@@ -3,12 +3,16 @@ from __future__ import absolute_import, division, print_function
 
 import os
-import gc
 import sys
 import json
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 import tensorflow as tf
+import tensorflow.compat.v1.logging as tflogging
+tflogging.set_verbosity(tflogging.ERROR)
+import logging
+logging.getLogger('sox').setLevel(logging.ERROR)
 
-from multiprocessing import cpu_count
+from multiprocessing import Process, cpu_count
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from util.config import Config, initialize_globals
 from util.audio import AudioFile
@@ -17,44 +21,41 @@ from util.flags import create_flags, FLAGS
 from util.logging import log_error, log_info, log_progress, create_progressbar
 
 
-def split_audio_file_flags(audio_file):
-    return split_audio_file(audio_file,
-                            batch_size=FLAGS.batch_size,
-                            aggressiveness=FLAGS.vad_aggressiveness,
-                            outlier_duration_ms=FLAGS.outlier_duration_ms,
-                            outlier_batch_size=FLAGS.outlier_batch_size)
+def fail(message, code=1):
+    log_error(message)
+    sys.exit(code)
 
 
-def transcribe(path_pairs, create_model, try_loading):
+def transcribe_file(audio_path, tlog_path):
+    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import,import-outside-toplevel
+    initialize_globals()
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
-    audio_path, _ = path_pairs[0]
-    data_set = split_audio_file_flags(None)
-    iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
-                                               output_classes=data_set.output_classes)
-    batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
-    no_dropout = [None] * 6
-    logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
-    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
-    tf.train.get_or_create_global_step()
     try:
         num_processes = cpu_count()
     except NotImplementedError:
         num_processes = 1
-    saver = tf.train.Saver()
-
-    with tf.Session(config=Config.session_config) as session:
-        loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded:
-            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
-        if not loaded:
-            log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'
-                      .format(FLAGS.checkpoint_dir))
-            sys.exit(1)
-
-        def run_transcription(p_index, p_data_set, p_audio_path, p_tlog_path):
-            bar = create_progressbar(prefix='Transcribing file {} "{}" | '.format(p_index, p_audio_path)).start()
-            log_progress('Transcribing file {}, "{}"...'.format(p_index, p_audio_path))
-            session.run(iterator.make_initializer(p_data_set))
+    with AudioFile(audio_path, as_path=True) as wav_path:
+        data_set = split_audio_file(wav_path,
+                                    batch_size=FLAGS.batch_size,
+                                    aggressiveness=FLAGS.vad_aggressiveness,
+                                    outlier_duration_ms=FLAGS.outlier_duration_ms,
+                                    outlier_batch_size=FLAGS.outlier_batch_size)
+        iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
+                                                   output_classes=data_set.output_classes)
+        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
+        no_dropout = [None] * 6
+        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
+        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
+        tf.train.get_or_create_global_step()
+        saver = tf.train.Saver()
+        with tf.Session(config=Config.session_config) as session:
+            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
+            if not loaded:
+                loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
+            if not loaded:
+                fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
+                     .format(FLAGS.checkpoint_dir))
+            session.run(iterator.make_initializer(data_set))
             transcripts = []
             while True:
                 try:
@@ -67,21 +68,28 @@ def transcribe(path_pairs, create_model, try_loading):
                                                         scorer=scorer)
                 decoded = list(d[0][1] for d in decoded)
                 transcripts.extend(zip(starts, ends, decoded))
-                bar.update(len(transcripts))
-            bar.finish()
 
             transcripts.sort(key=lambda t: t[0])
             transcripts = [{'start': int(start), 'end': int(end), 'transcript': transcript}
                            for start, end, transcript in transcripts]
-            log_info('Writing transcript log to "{}"...'.format(p_tlog_path))
-            with open(p_tlog_path, 'w') as tlog_file:
+            with open(tlog_path, 'w') as tlog_file:
                 json.dump(transcripts, tlog_file, default=float)
 
-    for index, (audio_path, tlog_path) in enumerate(path_pairs):
-        with AudioFile(audio_path, as_path=True) as wav_path:
-            data_set = split_audio_file_flags(wav_path)
-            run_transcription(index, data_set, audio_path, tlog_path)
-            gc.collect()
+
+def transcribe_many(path_pairs):
+    pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
+    for i, (src_path, dst_path) in enumerate(path_pairs):
+        p = Process(target=transcribe_file, args=(src_path, dst_path))
+        p.start()
+        p.join()
+        log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
+        pbar.update(i)
+    pbar.finish()
+
+
+def transcribe_one(src_path, dst_path):
+    transcribe_file(src_path, dst_path)
+    log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
 
 
 def resolve(base_path, spec_path):
@@ -93,49 +101,36 @@ def resolve(base_path, spec_path):
 
 
 def main(_):
-    initialize_globals()
-
     if not FLAGS.src:
-        log_error('You have to specify which file or catalog to transcribe via the --src flag.')
-        sys.exit(1)
-
-    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import,import-outside-toplevel
-
+        fail('You have to specify which file or catalog to transcribe via the --src flag.')
     src_path = os.path.abspath(FLAGS.src)
     if not os.path.isfile(src_path):
-        log_error('Path in --src not existing')
-        sys.exit(1)
+        fail('Path in --src not existing')
     if src_path.endswith('.catalog'):
         if FLAGS.dst:
-            log_error('Parameter --dst not supported if --src points to a catalog')
-            sys.exit(1)
+            fail('Parameter --dst not supported if --src points to a catalog')
        catalog_dir = os.path.dirname(src_path)
         with open(src_path, 'r') as catalog_file:
             catalog_entries = json.load(catalog_file)
         catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
         if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
-            log_error('Missing source file(s) in catalog')
-            sys.exit(1)
+            fail('Missing source file(s) in catalog')
         if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
-            log_error('Destination file(s) from catalog already existing, use --force for overwriting')
-            sys.exit(1)
+            fail('Destination file(s) from catalog already existing, use --force for overwriting')
         if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
-            log_error('Missing destination directory for at least one catalog entry')
-            sys.exit(1)
-        transcribe(catalog_entries, create_model, try_loading)
+            fail('Missing destination directory for at least one catalog entry')
+        transcribe_many(catalog_entries)
     else:
         dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
         if os.path.isfile(dst_path):
            if FLAGS.force:
-                transcribe([(src_path, dst_path)], create_model, try_loading)
+                transcribe_one(src_path, dst_path)
            else:
-                log_error('Destination file "{}" already existing - requires --force for overwriting'.format(dst_path))
-                sys.exit(0)
+                fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
         elif os.path.isdir(os.path.dirname(dst_path)):
-            transcribe([(src_path, dst_path)], create_model, try_loading)
+            transcribe_one(src_path, dst_path)
         else:
-            log_error('Missing destination directory')
-            sys.exit(1)
+            fail('Missing destination directory')
 
 
 if __name__ == '__main__':
diff --git a/util/audio.py b/util/audio.py
index 65d10e64..262f66a3 100644
--- a/util/audio.py
+++ b/util/audio.py
@@ -1,6 +1,5 @@
 import os
 import sox
-import time
 import wave
 import tempfile
 import collections
@@ -55,13 +54,6 @@ class AudioFile:
             self.open_file.close()
         _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
         convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
-        retry_count = 10
-        while retry_count > 0 and not (os.path.exists(self.tmp_file_path) and os.path.getsize(self.tmp_file_path) > 0):
-            retry_count -= 1
-            if retry_count == 0:
-                raise RuntimeError('Unable to read temporary .wav file')
-            time.sleep(1)
-            print('Trying to read temporary .wav file...')
         if self.as_path:
             return self.tmp_file_path
         self.open_file = wave.open(self.tmp_file_path, 'r')