Separate process per file; less log noise

This commit is contained in:
Tilman Kamp 2019-11-20 17:29:13 +01:00
parent c24c510fd9
commit 29528ed7b7
3 changed files with 66 additions and 77 deletions

8
DeepSpeech.py Executable file → Normal file
View File

@ -394,7 +394,7 @@ def log_grads_and_vars(grads_and_vars):
log_variable(variable, gradient=gradient)
def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
try:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
if not checkpoint:
@ -403,8 +403,10 @@ def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
saver.restore(session, checkpoint_path)
if load_step:
restored_step = session.run(tfv1.train.get_global_step())
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
else:
if log_success:
log_info('Restored variables from %s checkpoint at %s, step %d' %
(caption, checkpoint_path, restored_step))
elif log_success:
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
return True
except tf.errors.InvalidArgumentError as e:

View File

@ -3,12 +3,16 @@
from __future__ import absolute_import, division, print_function
import os
import gc
import sys
import json
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import tensorflow.compat.v1.logging as tflogging
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger('sox').setLevel(logging.ERROR)
from multiprocessing import cpu_count
from multiprocessing import Process, cpu_count
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import Config, initialize_globals
from util.audio import AudioFile
@ -17,44 +21,41 @@ from util.flags import create_flags, FLAGS
from util.logging import log_error, log_info, log_progress, create_progressbar
def split_audio_file_flags(audio_file):
return split_audio_file(audio_file,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size)
def fail(message, code=1):
log_error(message)
sys.exit(code)
def transcribe(path_pairs, create_model, try_loading):
def transcribe_file(audio_path, tlog_path):
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
audio_path, _ = path_pairs[0]
data_set = split_audio_file_flags(None)
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
output_classes=data_set.output_classes)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
saver = tf.train.Saver()
with tf.Session(config=Config.session_config) as session:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
if not loaded:
loaded = try_loading(session, saver, 'checkpoint', 'most recent')
if not loaded:
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'
.format(FLAGS.checkpoint_dir))
sys.exit(1)
def run_transcription(p_index, p_data_set, p_audio_path, p_tlog_path):
bar = create_progressbar(prefix='Transcribing file {} "{}" | '.format(p_index, p_audio_path)).start()
log_progress('Transcribing file {}, "{}"...'.format(p_index, p_audio_path))
session.run(iterator.make_initializer(p_data_set))
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(wav_path,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size)
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
output_classes=data_set.output_classes)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
saver = tf.train.Saver()
with tf.Session(config=Config.session_config) as session:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
if not loaded:
loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
if not loaded:
fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
.format(FLAGS.checkpoint_dir))
session.run(iterator.make_initializer(data_set))
transcripts = []
while True:
try:
@ -67,21 +68,28 @@ def transcribe(path_pairs, create_model, try_loading):
scorer=scorer)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
bar.update(len(transcripts))
bar.finish()
transcripts.sort(key=lambda t: t[0])
transcripts = [{'start': int(start),
'end': int(end),
'transcript': transcript} for start, end, transcript in transcripts]
log_info('Writing transcript log to "{}"...'.format(p_tlog_path))
with open(p_tlog_path, 'w') as tlog_file:
with open(tlog_path, 'w') as tlog_file:
json.dump(transcripts, tlog_file, default=float)
for index, (audio_path, tlog_path) in enumerate(path_pairs):
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file_flags(wav_path)
run_transcription(index, data_set, audio_path, tlog_path)
gc.collect()
def transcribe_many(path_pairs):
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
for i, (src_path, dst_path) in enumerate(path_pairs):
p = Process(target=transcribe_file, args=(src_path, dst_path))
p.start()
p.join()
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
pbar.update(i)
pbar.finish()
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
def resolve(base_path, spec_path):
@ -93,49 +101,36 @@ def resolve(base_path, spec_path):
def main(_):
initialize_globals()
if not FLAGS.src:
log_error('You have to specify which file or catalog to transcribe via the --src flag.')
sys.exit(1)
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
fail('You have to specify which file or catalog to transcribe via the --src flag.')
src_path = os.path.abspath(FLAGS.src)
if not os.path.isfile(src_path):
log_error('Path in --src not existing')
sys.exit(1)
fail('Path in --src not existing')
if src_path.endswith('.catalog'):
if FLAGS.dst:
log_error('Parameter --dst not supported if --src points to a catalog')
sys.exit(1)
fail('Parameter --dst not supported if --src points to a catalog')
catalog_dir = os.path.dirname(src_path)
with open(src_path, 'r') as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
log_error('Missing source file(s) in catalog')
sys.exit(1)
fail('Missing source file(s) in catalog')
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
log_error('Destination file(s) from catalog already existing, use --force for overwriting')
sys.exit(1)
fail('Destination file(s) from catalog already existing, use --force for overwriting')
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
log_error('Missing destination directory for at least one catalog entry')
sys.exit(1)
transcribe(catalog_entries, create_model, try_loading)
fail('Missing destination directory for at least one catalog entry')
transcribe_many(catalog_entries)
else:
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe([(src_path, dst_path)], create_model, try_loading)
transcribe_one(src_path, dst_path)
else:
log_error('Destination file "{}" already existing - requires --force for overwriting'.format(dst_path))
sys.exit(0)
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe([(src_path, dst_path)], create_model, try_loading)
transcribe_one(src_path, dst_path)
else:
log_error('Missing destination directory')
sys.exit(1)
fail('Missing destination directory')
if __name__ == '__main__':

View File

@ -1,6 +1,5 @@
import os
import sox
import time
import wave
import tempfile
import collections
@ -55,13 +54,6 @@ class AudioFile:
self.open_file.close()
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
retry_count = 10
while retry_count > 0 and not (os.path.exists(self.tmp_file_path) and os.path.getsize(self.tmp_file_path) > 0):
retry_count -= 1
if retry_count == 0:
raise RuntimeError('Unable to read temporary .wav file')
time.sleep(1)
print('Trying to read temporary .wav file...')
if self.as_path:
return self.tmp_file_path
self.open_file = wave.open(self.tmp_file_path, 'r')