Merge pull request #2538 from tilmankamp/transcribe
Tool for bulk transcription
This commit is contained in:
commit
f3d69147fe
8
DeepSpeech.py
Executable file → Normal file
8
DeepSpeech.py
Executable file → Normal file
@ -394,7 +394,7 @@ def log_grads_and_vars(grads_and_vars):
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
|
||||
def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
|
||||
try:
|
||||
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
|
||||
if not checkpoint:
|
||||
@ -403,8 +403,10 @@ def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
|
||||
saver.restore(session, checkpoint_path)
|
||||
if load_step:
|
||||
restored_step = session.run(tfv1.train.get_global_step())
|
||||
log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
|
||||
else:
|
||||
if log_success:
|
||||
log_info('Restored variables from %s checkpoint at %s, step %d' %
|
||||
(caption, checkpoint_path, restored_step))
|
||||
elif log_success:
|
||||
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
|
||||
return True
|
||||
except tf.errors.InvalidArgumentError as e:
|
||||
|
@ -11,6 +11,9 @@ absl-py
|
||||
# Requirements for building native_client files
|
||||
setuptools
|
||||
|
||||
# Requirements for transcribe.py
|
||||
webrtcvad
|
||||
|
||||
# Requirements for importers
|
||||
sox
|
||||
bs4
|
||||
|
152
transcribe.py
Executable file
152
transcribe.py
Executable file
@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
import tensorflow as tf
|
||||
import tensorflow.compat.v1.logging as tflogging
|
||||
tflogging.set_verbosity(tflogging.ERROR)
|
||||
import logging
|
||||
logging.getLogger('sox').setLevel(logging.ERROR)
|
||||
|
||||
from multiprocessing import Process, cpu_count
|
||||
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
|
||||
from util.config import Config, initialize_globals
|
||||
from util.audio import AudioFile
|
||||
from util.feeding import split_audio_file
|
||||
from util.flags import create_flags, FLAGS
|
||||
from util.logging import log_error, log_info, log_progress, create_progressbar
|
||||
|
||||
|
||||
def fail(message, code=1):
|
||||
log_error(message)
|
||||
sys.exit(code)
|
||||
|
||||
|
||||
def transcribe_file(audio_path, tlog_path):
|
||||
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import,import-outside-toplevel
|
||||
initialize_globals()
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
|
||||
try:
|
||||
num_processes = cpu_count()
|
||||
except NotImplementedError:
|
||||
num_processes = 1
|
||||
with AudioFile(audio_path, as_path=True) as wav_path:
|
||||
data_set = split_audio_file(wav_path,
|
||||
batch_size=FLAGS.batch_size,
|
||||
aggressiveness=FLAGS.vad_aggressiveness,
|
||||
outlier_duration_ms=FLAGS.outlier_duration_ms,
|
||||
outlier_batch_size=FLAGS.outlier_batch_size)
|
||||
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
|
||||
output_classes=data_set.output_classes)
|
||||
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
|
||||
no_dropout = [None] * 6
|
||||
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
|
||||
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
|
||||
tf.train.get_or_create_global_step()
|
||||
saver = tf.train.Saver()
|
||||
with tf.Session(config=Config.session_config) as session:
|
||||
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
|
||||
if not loaded:
|
||||
loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
|
||||
if not loaded:
|
||||
fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
|
||||
.format(FLAGS.checkpoint_dir))
|
||||
session.run(iterator.make_initializer(data_set))
|
||||
transcripts = []
|
||||
while True:
|
||||
try:
|
||||
starts, ends, batch_logits, batch_lengths = \
|
||||
session.run([batch_time_start, batch_time_end, transposed, batch_x_len])
|
||||
except tf.errors.OutOfRangeError:
|
||||
break
|
||||
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
|
||||
num_processes=num_processes,
|
||||
scorer=scorer)
|
||||
decoded = list(d[0][1] for d in decoded)
|
||||
transcripts.extend(zip(starts, ends, decoded))
|
||||
transcripts.sort(key=lambda t: t[0])
|
||||
transcripts = [{'start': int(start),
|
||||
'end': int(end),
|
||||
'transcript': transcript} for start, end, transcript in transcripts]
|
||||
with open(tlog_path, 'w') as tlog_file:
|
||||
json.dump(transcripts, tlog_file, default=float)
|
||||
|
||||
|
||||
def transcribe_many(path_pairs):
|
||||
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
|
||||
for i, (src_path, dst_path) in enumerate(path_pairs):
|
||||
p = Process(target=transcribe_file, args=(src_path, dst_path))
|
||||
p.start()
|
||||
p.join()
|
||||
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
|
||||
pbar.update(i)
|
||||
pbar.finish()
|
||||
|
||||
|
||||
def transcribe_one(src_path, dst_path):
|
||||
transcribe_file(src_path, dst_path)
|
||||
log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
|
||||
|
||||
|
||||
def resolve(base_path, spec_path):
|
||||
if spec_path is None:
|
||||
return None
|
||||
if not os.path.isabs(spec_path):
|
||||
spec_path = os.path.join(base_path, spec_path)
|
||||
return spec_path
|
||||
|
||||
|
||||
def main(_):
|
||||
if not FLAGS.src:
|
||||
fail('You have to specify which file or catalog to transcribe via the --src flag.')
|
||||
src_path = os.path.abspath(FLAGS.src)
|
||||
if not os.path.isfile(src_path):
|
||||
fail('Path in --src not existing')
|
||||
if src_path.endswith('.catalog'):
|
||||
if FLAGS.dst:
|
||||
fail('Parameter --dst not supported if --src points to a catalog')
|
||||
catalog_dir = os.path.dirname(src_path)
|
||||
with open(src_path, 'r') as catalog_file:
|
||||
catalog_entries = json.load(catalog_file)
|
||||
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
|
||||
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
|
||||
fail('Missing source file(s) in catalog')
|
||||
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
|
||||
fail('Destination file(s) from catalog already existing, use --force for overwriting')
|
||||
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
|
||||
fail('Missing destination directory for at least one catalog entry')
|
||||
transcribe_many(catalog_entries)
|
||||
else:
|
||||
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
|
||||
if os.path.isfile(dst_path):
|
||||
if FLAGS.force:
|
||||
transcribe_one(src_path, dst_path)
|
||||
else:
|
||||
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
|
||||
elif os.path.isdir(os.path.dirname(dst_path)):
|
||||
transcribe_one(src_path, dst_path)
|
||||
else:
|
||||
fail('Missing destination directory')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
tf.app.flags.DEFINE_string('src', '', 'source path to an audio file or directory to recursively scan '
|
||||
'for audio files. If --dst not set, transcription logs (.tlog) will be '
|
||||
'written in-place using the source filenames with '
|
||||
'suffix ".tlog" instead of ".wav".')
|
||||
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
|
||||
'If --src is a directory, this one also has to be a directory '
|
||||
'and the required sub-dir tree of --src will get replicated.')
|
||||
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
|
||||
'transcription logs (.tlog)')
|
||||
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
|
||||
'split audio')
|
||||
tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
|
||||
tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
|
||||
tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
|
||||
tf.app.run(main)
|
135
util/audio.py
Normal file
135
util/audio.py
Normal file
@ -0,0 +1,135 @@
|
||||
import os
|
||||
import sox
|
||||
import wave
|
||||
import tempfile
|
||||
import collections
|
||||
from webrtcvad import Vad
|
||||
|
||||
DEFAULT_RATE = 16000
|
||||
DEFAULT_CHANNELS = 1
|
||||
DEFAULT_WIDTH = 2
|
||||
DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)
|
||||
|
||||
|
||||
def get_audio_format(wav_file):
|
||||
return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()
|
||||
|
||||
|
||||
def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT):
|
||||
_, channels, width = audio_format
|
||||
return len(audio_data) // (channels * width)
|
||||
|
||||
|
||||
def get_duration(audio_data, audio_format=DEFAULT_FORMAT):
|
||||
return get_num_samples(audio_data, audio_format) / audio_format[0]
|
||||
|
||||
|
||||
def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT):
|
||||
return get_duration(audio_data, audio_format) * 1000
|
||||
|
||||
|
||||
def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
|
||||
sample_rate, channels, width = audio_format
|
||||
transformer = sox.Transformer()
|
||||
transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8)
|
||||
transformer.build(src_audio_path, dst_audio_path)
|
||||
|
||||
|
||||
class AudioFile:
|
||||
def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT):
|
||||
self.audio_path = audio_path
|
||||
self.audio_format = audio_format
|
||||
self.as_path = as_path
|
||||
self.open_file = None
|
||||
self.tmp_file_path = None
|
||||
|
||||
def __enter__(self):
|
||||
if self.audio_path.endswith('.wav'):
|
||||
self.open_file = wave.open(self.audio_path, 'r')
|
||||
if get_audio_format(self.open_file) == self.audio_format:
|
||||
if self.as_path:
|
||||
self.open_file.close()
|
||||
return self.audio_path
|
||||
return self.open_file
|
||||
self.open_file.close()
|
||||
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
|
||||
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
|
||||
if self.as_path:
|
||||
return self.tmp_file_path
|
||||
self.open_file = wave.open(self.tmp_file_path, 'r')
|
||||
return self.open_file
|
||||
|
||||
def __exit__(self, *args):
|
||||
if not self.as_path:
|
||||
self.open_file.close()
|
||||
if self.tmp_file_path is not None:
|
||||
os.remove(self.tmp_file_path)
|
||||
|
||||
|
||||
def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
|
||||
audio_format = get_audio_format(wav_file)
|
||||
frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
|
||||
while True:
|
||||
try:
|
||||
data = wav_file.readframes(frame_size)
|
||||
if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms:
|
||||
break
|
||||
yield data
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False):
|
||||
with AudioFile(audio_path, audio_format=audio_format) as wav_file:
|
||||
for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder):
|
||||
yield frame
|
||||
|
||||
|
||||
def vad_split(audio_frames,
|
||||
audio_format=DEFAULT_FORMAT,
|
||||
num_padding_frames=10,
|
||||
threshold=0.5,
|
||||
aggressiveness=3):
|
||||
sample_rate, channels, width = audio_format
|
||||
if channels != 1:
|
||||
raise ValueError('VAD-splitting requires mono samples')
|
||||
if width != 2:
|
||||
raise ValueError('VAD-splitting requires 16 bit samples')
|
||||
if sample_rate not in [8000, 16000, 32000, 48000]:
|
||||
raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000')
|
||||
if aggressiveness not in [0, 1, 2, 3]:
|
||||
raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3')
|
||||
ring_buffer = collections.deque(maxlen=num_padding_frames)
|
||||
triggered = False
|
||||
vad = Vad(int(aggressiveness))
|
||||
voiced_frames = []
|
||||
frame_duration_ms = 0
|
||||
frame_index = 0
|
||||
for frame_index, frame in enumerate(audio_frames):
|
||||
frame_duration_ms = get_duration_ms(frame, audio_format)
|
||||
if int(frame_duration_ms) not in [10, 20, 30]:
|
||||
raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
|
||||
is_speech = vad.is_speech(frame, sample_rate)
|
||||
if not triggered:
|
||||
ring_buffer.append((frame, is_speech))
|
||||
num_voiced = len([f for f, speech in ring_buffer if speech])
|
||||
if num_voiced > threshold * ring_buffer.maxlen:
|
||||
triggered = True
|
||||
for f, s in ring_buffer:
|
||||
voiced_frames.append(f)
|
||||
ring_buffer.clear()
|
||||
else:
|
||||
voiced_frames.append(frame)
|
||||
ring_buffer.append((frame, is_speech))
|
||||
num_unvoiced = len([f for f, speech in ring_buffer if not speech])
|
||||
if num_unvoiced > threshold * ring_buffer.maxlen:
|
||||
triggered = False
|
||||
yield b''.join(voiced_frames), \
|
||||
frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
|
||||
frame_duration_ms * frame_index
|
||||
ring_buffer.clear()
|
||||
voiced_frames = []
|
||||
if len(voiced_frames) > 0:
|
||||
yield b''.join(voiced_frames), \
|
||||
frame_duration_ms * (frame_index - len(voiced_frames)), \
|
||||
frame_duration_ms * (frame_index + 1)
|
@ -8,15 +8,15 @@ from functools import partial
|
||||
import numpy as np
|
||||
import pandas
|
||||
import tensorflow as tf
|
||||
import datetime
|
||||
|
||||
from tensorflow.python.ops import gen_audio_ops as contrib_audio
|
||||
|
||||
from util.config import Config
|
||||
from util.logging import log_error
|
||||
from util.text import text_to_char_array
|
||||
from util.flags import FLAGS
|
||||
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
|
||||
from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
|
||||
|
||||
|
||||
def read_csvs(csv_files):
|
||||
sets = []
|
||||
@ -134,6 +134,46 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def split_audio_file(audio_path,
|
||||
audio_format=DEFAULT_FORMAT,
|
||||
batch_size=1,
|
||||
aggressiveness=3,
|
||||
outlier_duration_ms=10000,
|
||||
outlier_batch_size=1):
|
||||
sample_rate, _, sample_width = audio_format
|
||||
multiplier = 1.0 / (1 << (8 * sample_width - 1))
|
||||
|
||||
def generate_values():
|
||||
frames = read_frames_from_file(audio_path)
|
||||
segments = vad_split(frames, aggressiveness=aggressiveness)
|
||||
for segment in segments:
|
||||
segment_buffer, time_start, time_end = segment
|
||||
samples = np.frombuffer(segment_buffer, dtype=np.int16)
|
||||
samples = samples * multiplier
|
||||
samples = np.expand_dims(samples, axis=1)
|
||||
yield time_start, time_end, samples
|
||||
|
||||
def to_mfccs(time_start, time_end, samples):
|
||||
features, features_len = samples_to_mfccs(samples, sample_rate)
|
||||
return time_start, time_end, features, features_len
|
||||
|
||||
def create_batch_set(bs, criteria):
|
||||
return (tf.data.Dataset
|
||||
.from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32))
|
||||
.map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
.filter(criteria)
|
||||
.padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], [])))
|
||||
|
||||
nds = create_batch_set(batch_size,
|
||||
lambda start, end, f, fl: end - start <= int(outlier_duration_ms))
|
||||
ods = create_batch_set(outlier_batch_size,
|
||||
lambda start, end, f, fl: end - start > int(outlier_duration_ms))
|
||||
dataset = nds.concatenate(ods)
|
||||
dataset = dataset.prefetch(len(Config.available_devices))
|
||||
return dataset
|
||||
|
||||
|
||||
def secs_to_hours(secs):
|
||||
hours, remainder = divmod(secs, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
|
Loading…
x
Reference in New Issue
Block a user