Merge pull request #2538 from tilmankamp/transcribe

Tool for bulk transcription
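transcribe.py splits long audio files into voiced segments with WebRTC VAD and feeds those segments through an acoustic model restored from the training checkpoints, writing one JSON transcription log (.tlog) per input file. It transcribes either a single file given via --src (for example python transcribe.py --src recording.wav, with the usual training flags such as --checkpoint_dir set as for evaluation) or a whole batch described by a .catalog file (see the sketch after transcribe.py below).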
Tilman Kamp 2019-11-21 13:05:08 +01:00 committed by GitHub
commit f3d69147fe
5 changed files with 337 additions and 5 deletions

DeepSpeech.py: 8 changes (Executable file → Normal file)

@@ -394,7 +394,7 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)


-def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
+def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
         if not checkpoint:
@@ -403,8 +403,10 @@ def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
         saver.restore(session, checkpoint_path)
         if load_step:
             restored_step = session.run(tfv1.train.get_global_step())
-            log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
-        else:
+            if log_success:
+                log_info('Restored variables from %s checkpoint at %s, step %d' %
+                         (caption, checkpoint_path, restored_step))
+        elif log_success:
             log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
     except tf.errors.InvalidArgumentError as e:
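The only change to DeepSpeech.py is the new log_success parameter, which lets callers probe checkpoints quietly so the per-file worker processes spawned in transcribe_many() don't each print a restore message. The probing call, quoted from transcribe.py below:

    loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)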

requirements.txt

@@ -11,6 +11,9 @@ absl-py
 # Requirements for building native_client files
 setuptools

+# Requirements for transcribe.py
+webrtcvad
+
 # Requirements for importers
 sox
 bs4
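webrtcvad is available on PyPI and can be installed with pip install webrtcvad or as part of pip install -r requirements.txt; it provides the Vad class used by the new util/audio.py.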

transcribe.py: 152 additions (new executable file)

@@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import os
import sys
import json

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorflow.compat.v1.logging as tflogging
tflogging.set_verbosity(tflogging.ERROR)

import logging
logging.getLogger('sox').setLevel(logging.ERROR)

from multiprocessing import Process, cpu_count
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import Config, initialize_globals
from util.audio import AudioFile
from util.feeding import split_audio_file
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_info, log_progress, create_progressbar


def fail(message, code=1):
    log_error(message)
    sys.exit(code)


def transcribe_file(audio_path, tlog_path):
    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import,import-outside-toplevel
    initialize_globals()
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1
    with AudioFile(audio_path, as_path=True) as wav_path:
        data_set = split_audio_file(wav_path,
                                    batch_size=FLAGS.batch_size,
                                    aggressiveness=FLAGS.vad_aggressiveness,
                                    outlier_duration_ms=FLAGS.outlier_duration_ms,
                                    outlier_batch_size=FLAGS.outlier_batch_size)
        iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
                                                   output_classes=data_set.output_classes)
        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
        no_dropout = [None] * 6
        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        tf.train.get_or_create_global_step()
        saver = tf.train.Saver()
        with tf.Session(config=Config.session_config) as session:
            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
            if not loaded:
                loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
            if not loaded:
                fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
                     .format(FLAGS.checkpoint_dir))
            session.run(iterator.make_initializer(data_set))
            transcripts = []
            while True:
                try:
                    starts, ends, batch_logits, batch_lengths = \
                        session.run([batch_time_start, batch_time_end, transposed, batch_x_len])
                except tf.errors.OutOfRangeError:
                    break
                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes,
                                                        scorer=scorer)
                decoded = list(d[0][1] for d in decoded)
                transcripts.extend(zip(starts, ends, decoded))
            transcripts.sort(key=lambda t: t[0])
            transcripts = [{'start': int(start),
                            'end': int(end),
                            'transcript': transcript} for start, end, transcript in transcripts]
            with open(tlog_path, 'w') as tlog_file:
                json.dump(transcripts, tlog_file, default=float)


def transcribe_many(path_pairs):
    pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
    for i, (src_path, dst_path) in enumerate(path_pairs):
        p = Process(target=transcribe_file, args=(src_path, dst_path))
        p.start()
        p.join()
        log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
        pbar.update(i)
    pbar.finish()


def transcribe_one(src_path, dst_path):
    transcribe_file(src_path, dst_path)
    log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))


def resolve(base_path, spec_path):
    if spec_path is None:
        return None
    if not os.path.isabs(spec_path):
        spec_path = os.path.join(base_path, spec_path)
    return spec_path


def main(_):
    if not FLAGS.src:
        fail('You have to specify which file or catalog to transcribe via the --src flag.')
    src_path = os.path.abspath(FLAGS.src)
    if not os.path.isfile(src_path):
        fail('Path in --src not existing')
    if src_path.endswith('.catalog'):
        if FLAGS.dst:
            fail('Parameter --dst not supported if --src points to a catalog')
        catalog_dir = os.path.dirname(src_path)
        with open(src_path, 'r') as catalog_file:
            catalog_entries = json.load(catalog_file)
        catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
        if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
            fail('Missing source file(s) in catalog')
        if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
            fail('Destination file(s) from catalog already existing, use --force for overwriting')
        if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
            fail('Missing destination directory for at least one catalog entry')
        transcribe_many(catalog_entries)
    else:
        dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
        if os.path.isfile(dst_path):
            if FLAGS.force:
                transcribe_one(src_path, dst_path)
            else:
                fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
        elif os.path.isdir(os.path.dirname(dst_path)):
            transcribe_one(src_path, dst_path)
        else:
            fail('Missing destination directory')


if __name__ == '__main__':
    create_flags()
    tf.app.flags.DEFINE_string('src', '', 'source path to an audio file or directory to recursively scan '
                                          'for audio files. If --dst not set, transcription logs (.tlog) will be '
                                          'written in-place using the source filenames with '
                                          'suffix ".tlog" instead of ".wav".')
    tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
                                          'If --src is a directory, this one also has to be a directory '
                                          'and the required sub-dir tree of --src will get replicated.')
    tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
                                                'transcription logs (.tlog)')
    tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
                                                         'split audio')
    tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
    tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
    tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
    tf.app.run(main)
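For batch operation, main() expects --src to point at a .catalog file: a JSON array of entries whose 'audio' and 'tlog' keys give the source audio and destination transcription-log paths, with relative paths resolved against the catalog's own directory. A minimal sketch of producing such a catalog (the file names are hypothetical; only the key names come from main() above):

    import json

    # Hypothetical paths, relative to the catalog's directory.
    catalog = [
        {'audio': 'recordings/part1.wav', 'tlog': 'transcripts/part1.tlog'},
        {'audio': 'recordings/part2.wav', 'tlog': 'transcripts/part2.tlog'},
    ]
    with open('batch.catalog', 'w') as catalog_file:
        json.dump(catalog, catalog_file)

main() refuses to run if any source file or destination directory is missing, and will not overwrite existing .tlog files unless --force is given.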

util/audio.py: 135 additions (new file)

@@ -0,0 +1,135 @@
import os
import sox
import wave
import tempfile
import collections

from webrtcvad import Vad

DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1
DEFAULT_WIDTH = 2
DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)


def get_audio_format(wav_file):
    return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()


def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT):
    _, channels, width = audio_format
    return len(audio_data) // (channels * width)


def get_duration(audio_data, audio_format=DEFAULT_FORMAT):
    return get_num_samples(audio_data, audio_format) / audio_format[0]


def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT):
    return get_duration(audio_data, audio_format) * 1000


def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
    sample_rate, channels, width = audio_format
    transformer = sox.Transformer()
    transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8)
    transformer.build(src_audio_path, dst_audio_path)


class AudioFile:
    def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT):
        self.audio_path = audio_path
        self.audio_format = audio_format
        self.as_path = as_path
        self.open_file = None
        self.tmp_file_path = None

    def __enter__(self):
        if self.audio_path.endswith('.wav'):
            self.open_file = wave.open(self.audio_path, 'r')
            if get_audio_format(self.open_file) == self.audio_format:
                if self.as_path:
                    self.open_file.close()
                    return self.audio_path
                return self.open_file
            self.open_file.close()
        _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
        convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
        if self.as_path:
            return self.tmp_file_path
        self.open_file = wave.open(self.tmp_file_path, 'r')
        return self.open_file

    def __exit__(self, *args):
        if not self.as_path:
            self.open_file.close()
        if self.tmp_file_path is not None:
            os.remove(self.tmp_file_path)


def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
    audio_format = get_audio_format(wav_file)
    frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
    while True:
        try:
            data = wav_file.readframes(frame_size)
            if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms:
                break
            yield data
        except EOFError:
            break


def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False):
    with AudioFile(audio_path, audio_format=audio_format) as wav_file:
        for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder):
            yield frame


def vad_split(audio_frames,
              audio_format=DEFAULT_FORMAT,
              num_padding_frames=10,
              threshold=0.5,
              aggressiveness=3):
    sample_rate, channels, width = audio_format
    if channels != 1:
        raise ValueError('VAD-splitting requires mono samples')
    if width != 2:
        raise ValueError('VAD-splitting requires 16 bit samples')
    if sample_rate not in [8000, 16000, 32000, 48000]:
        raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000')
    if aggressiveness not in [0, 1, 2, 3]:
        raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3')
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False
    vad = Vad(int(aggressiveness))
    voiced_frames = []
    frame_duration_ms = 0
    frame_index = 0
    for frame_index, frame in enumerate(audio_frames):
        frame_duration_ms = get_duration_ms(frame, audio_format)
        if int(frame_duration_ms) not in [10, 20, 30]:
            raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
        is_speech = vad.is_speech(frame, sample_rate)
        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            if num_voiced > threshold * ring_buffer.maxlen:
                triggered = True
                for f, s in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            if num_unvoiced > threshold * ring_buffer.maxlen:
                triggered = False
                yield b''.join(voiced_frames), \
                      frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
                      frame_duration_ms * frame_index
                ring_buffer.clear()
                voiced_frames = []
    if len(voiced_frames) > 0:
        yield b''.join(voiced_frames), \
              frame_duration_ms * (frame_index - len(voiced_frames)), \
              frame_duration_ms * (frame_index + 1)
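A minimal sketch (not part of the commit) of driving the splitter directly; the input path is hypothetical, and AudioFile transparently converts anything that is not already 16 kHz mono 16-bit WAV. Each yielded tuple is (raw PCM buffer, segment start in ms, segment end in ms):

    from util.audio import read_frames_from_file, vad_split

    # Read 30 ms frames from a (hypothetical) long recording and group them
    # into voiced segments.
    frames = read_frames_from_file('long_recording.wav')
    for segment_buffer, start_ms, end_ms in vad_split(frames, aggressiveness=3):
        print('voiced segment: {:.0f} ms to {:.0f} ms ({} bytes)'.format(
            start_ms, end_ms, len(segment_buffer)))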

util/feeding.py

@@ -8,15 +8,15 @@ from functools import partial
 import numpy as np
 import pandas
 import tensorflow as tf
-import datetime

 from tensorflow.python.ops import gen_audio_ops as contrib_audio

 from util.config import Config
 from util.logging import log_error
 from util.text import text_to_char_array
 from util.flags import FLAGS
 from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
+from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT


 def read_csvs(csv_files):
     sets = []
@@ -134,6 +134,46 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     return dataset


+def split_audio_file(audio_path,
+                     audio_format=DEFAULT_FORMAT,
+                     batch_size=1,
+                     aggressiveness=3,
+                     outlier_duration_ms=10000,
+                     outlier_batch_size=1):
+    sample_rate, _, sample_width = audio_format
+    multiplier = 1.0 / (1 << (8 * sample_width - 1))
+
+    def generate_values():
+        frames = read_frames_from_file(audio_path)
+        segments = vad_split(frames, aggressiveness=aggressiveness)
+        for segment in segments:
+            segment_buffer, time_start, time_end = segment
+            samples = np.frombuffer(segment_buffer, dtype=np.int16)
+            samples = samples * multiplier
+            samples = np.expand_dims(samples, axis=1)
+            yield time_start, time_end, samples
+
+    def to_mfccs(time_start, time_end, samples):
+        features, features_len = samples_to_mfccs(samples, sample_rate)
+        return time_start, time_end, features, features_len
+
+    def create_batch_set(bs, criteria):
+        return (tf.data.Dataset
+                .from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32))
+                .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+                .filter(criteria)
+                .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], [])))
+
+    nds = create_batch_set(batch_size,
+                           lambda start, end, f, fl: end - start <= int(outlier_duration_ms))
+    ods = create_batch_set(outlier_batch_size,
+                           lambda start, end, f, fl: end - start > int(outlier_duration_ms))
+    dataset = nds.concatenate(ods)
+    dataset = dataset.prefetch(len(Config.available_devices))
+    return dataset
+
+
 def secs_to_hours(secs):
     hours, remainder = divmod(secs, 3600)
     minutes, seconds = divmod(remainder, 60)
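split_audio_file() keeps rare, very long segments from inflating padded batches: segments up to outlier_duration_ms are batched at batch_size, longer ones go through a second pipeline at the small outlier_batch_size, and the two datasets are concatenated. A self-contained TF1-style sketch of just this filter/concatenate pattern on dummy durations (illustrative only, not from the commit):

    import tensorflow as tf

    # Dummy segment durations in ms; two of them exceed the 10000 ms cutoff.
    durations = tf.data.Dataset.from_tensor_slices([500, 800, 12000, 700, 15000])
    normal = durations.filter(lambda d: d <= 10000).batch(3)
    outliers = durations.filter(lambda d: d > 10000).batch(1)
    next_batch = tf.data.make_one_shot_iterator(normal.concatenate(outliers)).get_next()
    with tf.Session() as session:
        while True:
            try:
                print(session.run(next_batch))  # [500 800 700], then [12000], then [15000]
            except tf.errors.OutOfRangeError:
                break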