Separate process per file; less log noise
This commit is contained in:
parent
c24c510fd9
commit
29528ed7b7
8
DeepSpeech.py
Executable file → Normal file
8
DeepSpeech.py
Executable file → Normal file
@ -394,7 +394,7 @@ def log_grads_and_vars(grads_and_vars):
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
|
||||
def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
|
||||
try:
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
|
||||
if not checkpoint:
|
||||
@ -403,8 +403,10 @@ def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
|
||||
saver.restore(session, checkpoint_path)
|
||||
if load_step:
|
||||
restored_step = session.run(tfv1.train.get_global_step())
|
||||
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
|
||||
else:
|
||||
if log_success:
|
||||
log_info('Restored variables from %s checkpoint at %s, step %d' %
|
||||
(caption, checkpoint_path, restored_step))
|
||||
elif log_success:
|
||||
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
|
||||
return True
|
||||
except tf.errors.InvalidArgumentError as e:
|
||||
|
||||
127
transcribe.py
127
transcribe.py
@ -3,12 +3,16 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import gc
|
||||
import sys
|
||||
import json
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1.logging as tflogging
|
||||
tflogging.set_verbosity(tflogging.ERROR)
|
||||
import logging
|
||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||
|
||||
from multiprocessing import cpu_count
|
||||
from multiprocessing import Process, cpu_count
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from util.config import Config, initialize_globals
|
||||
from util.audio import AudioFile
|
||||
@ -17,44 +21,41 @@ from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error, log_info, log_progress, create_progressbar
|
||||
|
||||
|
||||
def split_audio_file_flags(audio_file):
|
||||
return split_audio_file(audio_file,
|
||||
batch_size=FLAGS.batch_size,
|
||||
aggressiveness=FLAGS.vad_aggressiveness,
|
||||
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
||||
outlier_batch_size=FLAGS.outlier_batch_size)
|
||||
def fail(message, code=1):
|
||||
log_error(message)
|
||||
sys.exit(code)
|
||||
|
||||
|
||||
def transcribe(path_pairs, create_model, try_loading):
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
|
||||
audio_path, _ = path_pairs[0]
|
||||
data_set = split_audio_file_flags(None)
|
||||
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
|
||||
output_classes=data_set.output_classes)
|
||||
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
tf.train.get_or_create_global_step()
|
||||
try:
|
||||
num_processes = cpu_count()
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
saver = tf.train.Saver()
|
||||
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
|
||||
if not loaded:
|
||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent')
|
||||
if not loaded:
|
||||
log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'
|
||||
.format(FLAGS.checkpoint_dir))
|
||||
sys.exit(1)
|
||||
|
||||
def run_transcription(p_index, p_data_set, p_audio_path, p_tlog_path):
|
||||
bar = create_progressbar(prefix='Transcribing file {} "{}" | '.format(p_index, p_audio_path)).start()
|
||||
log_progress('Transcribing file {}, "{}"...'.format(p_index, p_audio_path))
|
||||
session.run(iterator.make_initializer(p_data_set))
|
||||
with AudioFile(audio_path, as_path=True) as wav_path:
|
||||
data_set = split_audio_file(wav_path,
|
||||
batch_size=FLAGS.batch_size,
|
||||
aggressiveness=FLAGS.vad_aggressiveness,
|
||||
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
||||
outlier_batch_size=FLAGS.outlier_batch_size)
|
||||
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
|
||||
output_classes=data_set.output_classes)
|
||||
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
tf.train.get_or_create_global_step()
|
||||
saver = tf.train.Saver()
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
|
||||
if not loaded:
|
||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
|
||||
if not loaded:
|
||||
fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
|
||||
.format(FLAGS.checkpoint_dir))
|
||||
session.run(iterator.make_initializer(data_set))
|
||||
transcripts = []
|
||||
while True:
|
||||
try:
|
||||
@ -67,21 +68,28 @@ def transcribe(path_pairs, create_model, try_loading):
|
||||
scorer=scorer)
|
||||
decoded = list(d[0][1] for d in decoded)
|
||||
transcripts.extend(zip(starts, ends, decoded))
|
||||
bar.update(len(transcripts))
|
||||
bar.finish()
|
||||
transcripts.sort(key=lambda t: t[0])
|
||||
transcripts = [{'start': int(start),
|
||||
'end': int(end),
|
||||
'transcript': transcript} for start, end, transcript in transcripts]
|
||||
log_info('Writing transcript log to "{}"...'.format(p_tlog_path))
|
||||
with open(p_tlog_path, 'w') as tlog_file:
|
||||
with open(tlog_path, 'w') as tlog_file:
|
||||
json.dump(transcripts, tlog_file, default=float)
|
||||
|
||||
for index, (audio_path, tlog_path) in enumerate(path_pairs):
|
||||
with AudioFile(audio_path, as_path=True) as wav_path:
|
||||
data_set = split_audio_file_flags(wav_path)
|
||||
run_transcription(index, data_set, audio_path, tlog_path)
|
||||
gc.collect()
|
||||
|
||||
def transcribe_many(path_pairs):
|
||||
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
|
||||
for i, (src_path, dst_path) in enumerate(path_pairs):
|
||||
p = Process(target=transcribe_file, args=(src_path, dst_path))
|
||||
p.start()
|
||||
p.join()
|
||||
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
|
||||
pbar.update(i)
|
||||
pbar.finish()
|
||||
|
||||
|
||||
def transcribe_one(src_path, dst_path):
|
||||
transcribe_file(src_path, dst_path)
|
||||
log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
|
||||
|
||||
|
||||
def resolve(base_path, spec_path):
|
||||
@ -93,49 +101,36 @@ def resolve(base_path, spec_path):
|
||||
|
||||
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if not FLAGS.src:
|
||||
log_error('You have to specify which file or catalog to transcribe via the --src flag.')
|
||||
sys.exit(1)
|
||||
|
||||
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
|
||||
fail('You have to specify which file or catalog to transcribe via the --src flag.')
|
||||
src_path = os.path.abspath(FLAGS.src)
|
||||
if not os.path.isfile(src_path):
|
||||
log_error('Path in --src not existing')
|
||||
sys.exit(1)
|
||||
fail('Path in --src not existing')
|
||||
if src_path.endswith('.catalog'):
|
||||
if FLAGS.dst:
|
||||
log_error('Parameter --dst not supported if --src points to a catalog')
|
||||
sys.exit(1)
|
||||
fail('Parameter --dst not supported if --src points to a catalog')
|
||||
catalog_dir = os.path.dirname(src_path)
|
||||
with open(src_path, 'r') as catalog_file:
|
||||
catalog_entries = json.load(catalog_file)
|
||||
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
|
||||
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
|
||||
log_error('Missing source file(s) in catalog')
|
||||
sys.exit(1)
|
||||
fail('Missing source file(s) in catalog')
|
||||
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
|
||||
log_error('Destination file(s) from catalog already existing, use --force for overwriting')
|
||||
sys.exit(1)
|
||||
fail('Destination file(s) from catalog already existing, use --force for overwriting')
|
||||
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
|
||||
log_error('Missing destination directory for at least one catalog entry')
|
||||
sys.exit(1)
|
||||
transcribe(catalog_entries, create_model, try_loading)
|
||||
fail('Missing destination directory for at least one catalog entry')
|
||||
transcribe_many(catalog_entries)
|
||||
else:
|
||||
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
|
||||
if os.path.isfile(dst_path):
|
||||
if FLAGS.force:
|
||||
transcribe([(src_path, dst_path)], create_model, try_loading)
|
||||
transcribe_one(src_path, dst_path)
|
||||
else:
|
||||
log_error('Destination file "{}" already existing - requires --force for overwriting'.format(dst_path))
|
||||
sys.exit(0)
|
||||
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
|
||||
elif os.path.isdir(os.path.dirname(dst_path)):
|
||||
transcribe([(src_path, dst_path)], create_model, try_loading)
|
||||
transcribe_one(src_path, dst_path)
|
||||
else:
|
||||
log_error('Missing destination directory')
|
||||
sys.exit(1)
|
||||
fail('Missing destination directory')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import os
|
||||
import sox
|
||||
import time
|
||||
import wave
|
||||
import tempfile
|
||||
import collections
|
||||
@ -55,13 +54,6 @@ class AudioFile:
|
||||
self.open_file.close()
|
||||
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
|
||||
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
|
||||
retry_count = 10
|
||||
while retry_count > 0 and not (os.path.exists(self.tmp_file_path) and os.path.getsize(self.tmp_file_path) > 0):
|
||||
retry_count -= 1
|
||||
if retry_count == 0:
|
||||
raise RuntimeError('Unable to read temporary .wav file')
|
||||
time.sleep(1)
|
||||
print('Trying to read temporary .wav file...')
|
||||
if self.as_path:
|
||||
return self.tmp_file_path
|
||||
self.open_file = wave.open(self.tmp_file_path, 'r')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user