Merge pull request #2538 from tilmankamp/transcribe

Tool for bulk transcription

Commit f3d69147fe

DeepSpeech.py (8 changed lines, Executable file → Normal file)
@@ -394,7 +394,7 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)
 
 
-def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
+def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
     try:
         checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
         if not checkpoint:
@@ -403,8 +403,10 @@ def try_loading(session, saver, checkpoint_filename, caption, load_step=True):
         saver.restore(session, checkpoint_path)
         if load_step:
             restored_step = session.run(tfv1.train.get_global_step())
-            log_info('Restored variables from %s checkpoint at %s, step %d' % (caption, checkpoint_path, restored_step))
-        else:
+            if log_success:
+                log_info('Restored variables from %s checkpoint at %s, step %d' %
+                         (caption, checkpoint_path, restored_step))
+        elif log_success:
             log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
         return True
     except tf.errors.InvalidArgumentError as e:
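The new log_success flag lets a caller probe several checkpoint files without emitting a log line per attempt. A minimal caller sketch (the same pattern transcribe.py below uses; session and saver assumed to exist):

    # Probe the best-validation checkpoint first, then fall back to the most
    # recent one; both attempts stay silent thanks to log_success=False.
    loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
    if not loaded:
        loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)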
requirements.txt

@@ -11,6 +11,9 @@ absl-py
 # Requirements for building native_client files
 setuptools
 
+# Requirements for transcribe.py
+webrtcvad
+
 # Requirements for importers
 sox
 bs4
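The new webrtcvad dependency backs the voice-activity detection used by util/audio.py below. A quick import check (hypothetical snippet, not part of the diff):

    # Verify the dependency is importable; Vad(3) selects the most aggressive
    # filtering mode, which is also transcribe.py's default.
    from webrtcvad import Vad
    vad = Vad(3)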
transcribe.py (new Executable file, 152 lines)

@@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import os
import sys
import json

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorflow.compat.v1.logging as tflogging
tflogging.set_verbosity(tflogging.ERROR)

import logging
logging.getLogger('sox').setLevel(logging.ERROR)

from multiprocessing import Process, cpu_count
from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from util.config import Config, initialize_globals
from util.audio import AudioFile
from util.feeding import split_audio_file
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_info, log_progress, create_progressbar


def fail(message, code=1):
    log_error(message)
    sys.exit(code)


def transcribe_file(audio_path, tlog_path):
    from DeepSpeech import create_model, try_loading  # pylint: disable=cyclic-import,import-outside-toplevel
    initialize_globals()
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet)
    try:
        num_processes = cpu_count()
    except NotImplementedError:
        num_processes = 1
    with AudioFile(audio_path, as_path=True) as wav_path:
        data_set = split_audio_file(wav_path,
                                    batch_size=FLAGS.batch_size,
                                    aggressiveness=FLAGS.vad_aggressiveness,
                                    outlier_duration_ms=FLAGS.outlier_duration_ms,
                                    outlier_batch_size=FLAGS.outlier_batch_size)
        iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
                                                   output_classes=data_set.output_classes)
        batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
        no_dropout = [None] * 6
        logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout)
        transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
        tf.train.get_or_create_global_step()
        saver = tf.train.Saver()
        with tf.Session(config=Config.session_config) as session:
            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', log_success=False)
            if not loaded:
                loaded = try_loading(session, saver, 'checkpoint', 'most recent', log_success=False)
            if not loaded:
                fail('Checkpoint directory ({}) does not contain a valid checkpoint state.'
                     .format(FLAGS.checkpoint_dir))
            session.run(iterator.make_initializer(data_set))
            transcripts = []
            while True:
                try:
                    starts, ends, batch_logits, batch_lengths = \
                        session.run([batch_time_start, batch_time_end, transposed, batch_x_len])
                except tf.errors.OutOfRangeError:
                    break
                decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width,
                                                        num_processes=num_processes,
                                                        scorer=scorer)
                decoded = list(d[0][1] for d in decoded)
                transcripts.extend(zip(starts, ends, decoded))
            transcripts.sort(key=lambda t: t[0])
            transcripts = [{'start': int(start),
                            'end': int(end),
                            'transcript': transcript} for start, end, transcript in transcripts]
            with open(tlog_path, 'w') as tlog_file:
                json.dump(transcripts, tlog_file, default=float)


def transcribe_many(path_pairs):
    pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(path_pairs)).start()
    for i, (src_path, dst_path) in enumerate(path_pairs):
        p = Process(target=transcribe_file, args=(src_path, dst_path))
        p.start()
        p.join()
        log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(path_pairs), src_path, dst_path))
        pbar.update(i)
    pbar.finish()


def transcribe_one(src_path, dst_path):
    transcribe_file(src_path, dst_path)
    log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))


def resolve(base_path, spec_path):
    if spec_path is None:
        return None
    if not os.path.isabs(spec_path):
        spec_path = os.path.join(base_path, spec_path)
    return spec_path


def main(_):
    if not FLAGS.src:
        fail('You have to specify which file or catalog to transcribe via the --src flag.')
    src_path = os.path.abspath(FLAGS.src)
    if not os.path.isfile(src_path):
        fail('Path in --src not existing')
    if src_path.endswith('.catalog'):
        if FLAGS.dst:
            fail('Parameter --dst not supported if --src points to a catalog')
        catalog_dir = os.path.dirname(src_path)
        with open(src_path, 'r') as catalog_file:
            catalog_entries = json.load(catalog_file)
        catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries]
        if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
            fail('Missing source file(s) in catalog')
        if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)):
            fail('Destination file(s) from catalog already existing, use --force for overwriting')
        if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)):
            fail('Missing destination directory for at least one catalog entry')
        transcribe_many(catalog_entries)
    else:
        dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog'
        if os.path.isfile(dst_path):
            if FLAGS.force:
                transcribe_one(src_path, dst_path)
            else:
                fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0)
        elif os.path.isdir(os.path.dirname(dst_path)):
            transcribe_one(src_path, dst_path)
        else:
            fail('Missing destination directory')


if __name__ == '__main__':
    create_flags()
    tf.app.flags.DEFINE_string('src', '', 'source path to an audio file or directory to recursively scan '
                                          'for audio files. If --dst not set, transcription logs (.tlog) will be '
                                          'written in-place using the source filenames with '
                                          'suffix ".tlog" instead of ".wav".')
    tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). '
                                          'If --src is a directory, this one also has to be a directory '
                                          'and the required sub-dir tree of --src will get replicated.')
    tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing '
                                                'transcription logs (.tlog)')
    tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should '
                                                         'split audio')
    tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size')
    tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers')
    tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)')
    tf.app.run(main)
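For batch jobs, --src can point to a .catalog file: a JSON list of entries whose 'audio' and 'tlog' paths are resolved relative to the catalog's own directory (see resolve() and main() above). A sketch of building one, with made-up file names, that `python transcribe.py --src batch.catalog` would then accept:

    # Write a minimal catalog; relative paths resolve against the catalog's
    # directory, and every destination directory must already exist.
    import json

    entries = [
        {'audio': 'recordings/interview1.wav', 'tlog': 'transcripts/interview1.tlog'},
        {'audio': 'recordings/interview2.wav', 'tlog': 'transcripts/interview2.tlog'},
    ]
    with open('batch.catalog', 'w') as catalog_file:
        json.dump(entries, catalog_file)

Note that main() refuses to overwrite existing .tlog files unless --force is given.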
util/audio.py (new file, 135 lines)

@@ -0,0 +1,135 @@
import os
import sox
import wave
import tempfile
import collections
from webrtcvad import Vad

DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1
DEFAULT_WIDTH = 2
DEFAULT_FORMAT = (DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)


def get_audio_format(wav_file):
    return wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()


def get_num_samples(audio_data, audio_format=DEFAULT_FORMAT):
    _, channels, width = audio_format
    return len(audio_data) // (channels * width)


def get_duration(audio_data, audio_format=DEFAULT_FORMAT):
    return get_num_samples(audio_data, audio_format) / audio_format[0]


def get_duration_ms(audio_data, audio_format=DEFAULT_FORMAT):
    return get_duration(audio_data, audio_format) * 1000


def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT):
    sample_rate, channels, width = audio_format
    transformer = sox.Transformer()
    transformer.set_output_format(file_type=file_type, rate=sample_rate, channels=channels, bits=width*8)
    transformer.build(src_audio_path, dst_audio_path)


class AudioFile:
    def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT):
        self.audio_path = audio_path
        self.audio_format = audio_format
        self.as_path = as_path
        self.open_file = None
        self.tmp_file_path = None

    def __enter__(self):
        if self.audio_path.endswith('.wav'):
            self.open_file = wave.open(self.audio_path, 'r')
            if get_audio_format(self.open_file) == self.audio_format:
                if self.as_path:
                    self.open_file.close()
                    return self.audio_path
                return self.open_file
            self.open_file.close()
        _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
        convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
        if self.as_path:
            return self.tmp_file_path
        self.open_file = wave.open(self.tmp_file_path, 'r')
        return self.open_file

    def __exit__(self, *args):
        if not self.as_path:
            self.open_file.close()
        if self.tmp_file_path is not None:
            os.remove(self.tmp_file_path)


def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
    audio_format = get_audio_format(wav_file)
    frame_size = int(audio_format[0] * (frame_duration_ms / 1000.0))
    while True:
        try:
            data = wav_file.readframes(frame_size)
            if not yield_remainder and get_duration_ms(data, audio_format) < frame_duration_ms:
                break
            yield data
        except EOFError:
            break


def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False):
    with AudioFile(audio_path, audio_format=audio_format) as wav_file:
        for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder):
            yield frame


def vad_split(audio_frames,
              audio_format=DEFAULT_FORMAT,
              num_padding_frames=10,
              threshold=0.5,
              aggressiveness=3):
    sample_rate, channels, width = audio_format
    if channels != 1:
        raise ValueError('VAD-splitting requires mono samples')
    if width != 2:
        raise ValueError('VAD-splitting requires 16 bit samples')
    if sample_rate not in [8000, 16000, 32000, 48000]:
        raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000')
    if aggressiveness not in [0, 1, 2, 3]:
        raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3')
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False
    vad = Vad(int(aggressiveness))
    voiced_frames = []
    frame_duration_ms = 0
    frame_index = 0
    for frame_index, frame in enumerate(audio_frames):
        frame_duration_ms = get_duration_ms(frame, audio_format)
        if int(frame_duration_ms) not in [10, 20, 30]:
            raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms')
        is_speech = vad.is_speech(frame, sample_rate)
        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            if num_voiced > threshold * ring_buffer.maxlen:
                triggered = True
                for f, s in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            if num_unvoiced > threshold * ring_buffer.maxlen:
                triggered = False
                yield b''.join(voiced_frames), \
                      frame_duration_ms * max(0, frame_index - len(voiced_frames)), \
                      frame_duration_ms * frame_index
                ring_buffer.clear()
                voiced_frames = []
    if len(voiced_frames) > 0:
        yield b''.join(voiced_frames), \
              frame_duration_ms * (frame_index - len(voiced_frames)), \
              frame_duration_ms * (frame_index + 1)
util/feeding.py

@@ -8,15 +8,15 @@ from functools import partial
 import numpy as np
 import pandas
 import tensorflow as tf
-import datetime
 
 from tensorflow.python.ops import gen_audio_ops as contrib_audio
 
 from util.config import Config
-from util.logging import log_error
 from util.text import text_to_char_array
 from util.flags import FLAGS
 from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
+from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
+
 
 def read_csvs(csv_files):
     sets = []
@@ -134,6 +134,46 @@ def create_dataset(csvs, batch_size, cache_path='', train_phase=False):
     return dataset
 
 
+def split_audio_file(audio_path,
+                     audio_format=DEFAULT_FORMAT,
+                     batch_size=1,
+                     aggressiveness=3,
+                     outlier_duration_ms=10000,
+                     outlier_batch_size=1):
+    sample_rate, _, sample_width = audio_format
+    multiplier = 1.0 / (1 << (8 * sample_width - 1))
+
+    def generate_values():
+        frames = read_frames_from_file(audio_path)
+        segments = vad_split(frames, aggressiveness=aggressiveness)
+        for segment in segments:
+            segment_buffer, time_start, time_end = segment
+            samples = np.frombuffer(segment_buffer, dtype=np.int16)
+            samples = samples * multiplier
+            samples = np.expand_dims(samples, axis=1)
+            yield time_start, time_end, samples
+
+    def to_mfccs(time_start, time_end, samples):
+        features, features_len = samples_to_mfccs(samples, sample_rate)
+        return time_start, time_end, features, features_len
+
+    def create_batch_set(bs, criteria):
+        return (tf.data.Dataset
+                .from_generator(generate_values, output_types=(tf.int32, tf.int32, tf.float32))
+                .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+                .filter(criteria)
+                .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], [])))
+
+    nds = create_batch_set(batch_size,
+                           lambda start, end, f, fl: end - start <= int(outlier_duration_ms))
+    ods = create_batch_set(outlier_batch_size,
+                           lambda start, end, f, fl: end - start > int(outlier_duration_ms))
+    dataset = nds.concatenate(ods)
+    dataset = dataset.prefetch(len(Config.available_devices))
+    return dataset
+
+
 def secs_to_hours(secs):
     hours, remainder = divmod(secs, 3600)
     minutes, seconds = divmod(remainder, 60)
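Segments longer than outlier_duration_ms go into a second dataset with its own (typically smaller) outlier_batch_size, presumably so a handful of very long segments cannot inflate the padded batches; the two datasets are then concatenated. A consumption sketch mirroring transcribe.py above (TF1 session API; assumes initialize_globals() has populated Config):

    data_set = split_audio_file('audio.wav', batch_size=40)
    iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes,
                                               output_classes=data_set.output_classes)
    starts, ends, features, features_len = iterator.get_next()
    with tf.Session(config=Config.session_config) as session:
        session.run(iterator.make_initializer(data_set))
        while True:
            try:
                session.run([starts, ends, features, features_len])
            except tf.errors.OutOfRangeError:
                break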