# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

from collections import Counter
from functools import partial

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import gen_audio_ops as contrib_audio

from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
from .config import Config
from .helpers import MEGABYTE
from .sample_collections import samples_from_sources
from .text import text_to_char_array


def audio_to_features(
    audio,
    sample_rate,
    transcript=None,
    clock=0.0,
    train_phase=False,
    augmentations=None,
    sample_id=None,
):
    """Compute MFCC features for a mono audio tensor.

    Returns a tuple of the feature tensor, shaped [time, Config.n_input],
    and its length along the time dimension.
    """
    if train_phase:
        # We need the lambdas to make TensorFlow happy.
        # pylint: disable=unnecessary-lambda
        tf.cond(
            tf.math.not_equal(sample_rate, Config.audio_sample_rate),
            lambda: tf.print(
                "WARNING: sample rate of sample",
                sample_id,
                "(",
                sample_rate,
                ") does not match Config.audio_sample_rate. "
                "This can lead to incorrect results.",
            ),
            lambda: tf.no_op(),
            name="matching_sample_rate",
        )

    if train_phase and augmentations:
        audio = apply_graph_augmentations(
            "signal", audio, augmentations, transcript=transcript, clock=clock
        )

    spectrogram = contrib_audio.audio_spectrogram(
        audio,
        window_size=Config.audio_window_samples,
        stride=Config.audio_step_samples,
        magnitude_squared=True,
    )

    if train_phase and augmentations:
        spectrogram = apply_graph_augmentations(
            "spectrogram",
            spectrogram,
            augmentations,
            transcript=transcript,
            clock=clock,
        )

    features = contrib_audio.mfcc(
        spectrogram=spectrogram,
        sample_rate=sample_rate,
        dct_coefficient_count=Config.n_input,
        upper_frequency_limit=Config.audio_sample_rate / 2,
    )
    features = tf.reshape(features, [-1, Config.n_input])

    if train_phase and augmentations:
        features = apply_graph_augmentations(
            "features", features, augmentations, transcript=transcript, clock=clock
        )

    return features, tf.shape(input=features)[0]


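# A minimal usage sketch, assuming Config has been initialized (e.g. via the
# project's initialize_globals()) and `audio` is a [samples, 1] float32 tensor
# at Config.audio_sample_rate:
#
#   features, features_len = audio_to_features(
#       audio, tf.constant(Config.audio_sample_rate)
#   )
#   # features: [time, Config.n_input] MFCC frames, features_len: frame count

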
def audiofile_to_features(
    wav_filename, clock=0.0, train_phase=False, augmentations=None
):
    """Read a WAV file from disk and compute its features."""
    samples = tf.io.read_file(wav_filename)
    return wavfile_bytes_to_features(
        samples, clock, train_phase, augmentations, sample_id=wav_filename
    )


def wavfile_bytes_to_features(
    samples, clock=0.0, train_phase=False, augmentations=None, sample_id=None
):
    """Decode a WAV byte string to mono audio and compute its features."""
    decoded = contrib_audio.decode_wav(samples, desired_channels=1)
    return audio_to_features(
        decoded.audio,
        decoded.sample_rate,
        clock=clock,
        train_phase=train_phase,
        augmentations=augmentations,
        sample_id=sample_id,
    )


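# A minimal usage sketch; "sample.wav" is a placeholder for a mono PCM WAV
# file recorded at Config.audio_sample_rate:
#
#   features, features_len = audiofile_to_features(tf.constant("sample.wav"))

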
def entry_to_features(
    sample_id,
    audio,
    sample_rate,
    transcript,
    clock,
    train_phase=False,
    augmentations=None,
):
    """Graph-side mapper from one generator entry to training features.

    The transcript arrives as an (indices, values, shape) tuple (see
    to_sparse_tuple) and is rebuilt into a SparseTensor here.
    """
    # https://bugs.python.org/issue32117
    sparse_transcript = tf.SparseTensor(*transcript)
    features, features_len = audio_to_features(
        audio,
        sample_rate,
        transcript=sparse_transcript,
        clock=clock,
        train_phase=train_phase,
        augmentations=augmentations,
        sample_id=sample_id,
    )
    return sample_id, features, features_len, sparse_transcript


def to_sparse_tuple(sequence):
    r"""Creates a sparse representation of ``sequence``.

    Returns a tuple of (indices, values, shape).
    """
    indices = np.asarray(
        list(zip([0] * len(sequence), range(len(sequence)))), dtype=np.int64
    )
    shape = np.asarray([1, len(sequence)], dtype=np.int64)
    return indices, sequence, shape


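# A worked example: for sequence [5, 2, 9], to_sparse_tuple returns
#
#   indices = [[0, 0], [0, 1], [0, 2]]  # (row, column) of every value
#   values  = [5, 2, 9]
#   shape   = [1, 3]                    # one row of length 3
#
# which tf.SparseTensor(*result) later rebuilds into a 2D sparse tensor.

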
def create_dataset(
    sources,
    batch_size,
    epochs=1,
    augmentations=None,
    cache_path=None,
    train_phase=False,
    reverse=False,
    limit=0,
    process_ahead=None,
    buffering=1 * MEGABYTE,
    epoch_ph=None,
):
    """Build a batched tf.data pipeline over the given sample sources.

    Each batch is a tuple of (sample_ids, (features, features_len), transcripts),
    with features padded to the longest sample in the batch and transcripts
    given as a 2D SparseTensor suitable for tf.nn.ctc_loss.
    """
    epoch_counter = Counter()  # survives restarts of the dataset and its generator

    def generate_values():
        epoch = epoch_counter["epoch"]
        if train_phase:
            epoch_counter["epoch"] += 1
        samples = samples_from_sources(
            sources, buffering=buffering, labeled=True, reverse=reverse
        )
        num_samples = len(samples)
        if limit > 0:
            num_samples = min(limit, num_samples)
        samples = apply_sample_augmentations(
            samples,
            augmentations,
            buffering=buffering,
            process_ahead=2 * batch_size if process_ahead is None else process_ahead,
            clock=epoch / epochs,
            final_clock=(epoch + 1) / epochs,
        )
        for sample_index, sample in enumerate(samples):
            if sample_index >= num_samples:
                break
            # Overall training progress in [0, 1), used to schedule augmentations.
            clock = (
                (epoch * num_samples + sample_index) / (epochs * num_samples)
                if train_phase and epochs > 0
                else 0.0
            )
            transcript = text_to_char_array(
                sample.transcript, Config.alphabet, context=sample.sample_id
            )
            transcript = to_sparse_tuple(transcript)
            yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript, clock

    # Batching a dataset of 2D SparseTensors creates 3D batches, which fail
    # when passed to tf.nn.ctc_loss, so we reshape them to remove the extra
    # dimension here.
    def sparse_reshape(sparse):
        shape = sparse.dense_shape
        return tf.sparse.reshape(sparse, [shape[0], shape[2]])

    def batch_fn(sample_ids, features, features_len, transcripts):
        features = tf.data.Dataset.zip((features, features_len))
        features = features.padded_batch(
            batch_size, padded_shapes=([None, Config.n_input], [])
        )
        transcripts = transcripts.batch(batch_size).map(sparse_reshape)
        sample_ids = sample_ids.batch(batch_size)
        return tf.data.Dataset.zip((sample_ids, features, transcripts))

    process_fn = partial(
        entry_to_features, train_phase=train_phase, augmentations=augmentations
    )

    dataset = tf.data.Dataset.from_generator(
        generate_values,
        output_types=(
            tf.string,
            tf.float32,
            tf.int32,
            (tf.int64, tf.int32, tf.int64),
            tf.float64,
        ),
    ).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if cache_path:
        dataset = dataset.cache(cache_path)
    dataset = dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)

    if Config.shuffle_batches and epoch_ph is not None:
        with tf.control_dependencies([tf.print("epoch:", epoch_ph)]):
            # Keep batch order stable until shuffle_start epochs have passed,
            # then shuffle with a buffer of shuffle_buffer batches.
            epoch_buffer_size = tf.cond(
                tf.less(epoch_ph, Config.shuffle_start),
                lambda: tf.constant(1, tf.int64),
                lambda: tf.constant(Config.shuffle_buffer, tf.int64),
            )
            dataset = dataset.shuffle(epoch_buffer_size, seed=epoch_ph)

    dataset = dataset.prefetch(len(Config.available_devices))
    return dataset


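# A minimal usage sketch; "train.csv" is a placeholder for a CSV (or other
# supported source) of labeled samples:
#
#   train_set = create_dataset(["train.csv"], batch_size=16, epochs=10,
#                              train_phase=True)
#   iterator = tf.compat.v1.data.make_one_shot_iterator(train_set)
#   sample_ids, (features, features_len), transcripts = iterator.get_next()

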
def split_audio_file(
    audio_path,
    audio_format=DEFAULT_FORMAT,
    batch_size=1,
    aggressiveness=3,
    outlier_duration_ms=10000,
    outlier_batch_size=1,
):
    """Split an audio file into voiced segments and batch their features.

    Segments are detected with vad_split (voice activity detection). Segments
    longer than outlier_duration_ms are batched separately with
    outlier_batch_size to keep padding overhead low.
    """
    def generate_values():
        frames = read_frames_from_file(audio_path)
        segments = vad_split(frames, aggressiveness=aggressiveness)
        for segment in segments:
            segment_buffer, time_start, time_end = segment
            samples = pcm_to_np(segment_buffer, audio_format)
            yield time_start, time_end, samples

    def to_mfccs(time_start, time_end, samples):
        features, features_len = audio_to_features(samples, audio_format.rate)
        return time_start, time_end, features, features_len

    def create_batch_set(bs, criteria):
        return (
            tf.data.Dataset.from_generator(
                generate_values,
                output_types=(tf.int32, tf.int32, tf.float32),
            )
            .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .filter(criteria)
            .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], []))
        )

    # Regular segments and over-long outliers get separate batch sets.
    nds = create_batch_set(
        batch_size, lambda start, end, f, fl: end - start <= int(outlier_duration_ms)
    )
    ods = create_batch_set(
        outlier_batch_size,
        lambda start, end, f, fl: end - start > int(outlier_duration_ms),
    )
    dataset = nds.concatenate(ods)
    dataset = dataset.prefetch(len(Config.available_devices))
    return dataset


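# A minimal usage sketch; "recording.wav" is a placeholder:
#
#   batches = split_audio_file("recording.wav", batch_size=8)
#   # each element: (time_start, time_end, features, features_len), with the
#   # times in milliseconds from the start of the file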