Add support for Ogg/Opus audio files for training
This commit is contained in:
parent
ad0f7d2ab7
commit
d4152f6e67
@ -9,14 +9,14 @@ import sys
|
||||
import random
|
||||
import argparse
|
||||
|
||||
from deepspeech_training.util.audio import LOADABLE_AUDIO_EXTENSIONS, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
|
||||
from deepspeech_training.util.audio import get_loadable_audio_type_from_extension, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV
|
||||
from deepspeech_training.util.sample_collections import SampleList, LabeledSample, samples_from_source
|
||||
from deepspeech_training.util.augmentations import parse_augmentations, apply_sample_augmentations, SampleAugmentation
|
||||
|
||||
|
||||
def get_samples_in_play_order():
|
||||
ext = os.path.splitext(CLI_ARGS.source)[1].lower()
|
||||
if ext in LOADABLE_AUDIO_EXTENSIONS:
|
||||
if get_loadable_audio_type_from_extension(ext):
|
||||
samples = SampleList([(CLI_ARGS.source, 0)], labeled=False)
|
||||
else:
|
||||
samples = samples_from_source(CLI_ARGS.source, buffering=0)
|
||||
|
27
setup.py
27
setup.py
@ -50,22 +50,23 @@ def main():
|
||||
version = fin.read().strip()
|
||||
|
||||
install_requires_base = [
|
||||
'numpy',
|
||||
'progressbar2',
|
||||
'six',
|
||||
'pyxdg',
|
||||
'attrdict',
|
||||
'absl-py',
|
||||
'semver',
|
||||
'opuslib == 2.0.0',
|
||||
'optuna',
|
||||
'sox',
|
||||
'attrdict',
|
||||
'bs4',
|
||||
'pandas',
|
||||
'requests',
|
||||
'numba == 0.47.0', # ships py3.5 wheel
|
||||
'llvmlite == 0.31.0', # for numba==0.47.0
|
||||
'librosa',
|
||||
'llvmlite == 0.31.0', # for numba==0.47.0
|
||||
'numba == 0.47.0', # ships py3.5 wheel
|
||||
'numpy',
|
||||
'optuna',
|
||||
'opuslib == 2.0.0',
|
||||
'pandas',
|
||||
'progressbar2',
|
||||
'pyogg >= 0.6.14a1',
|
||||
'pyxdg',
|
||||
'requests',
|
||||
'semver',
|
||||
'six',
|
||||
'sox',
|
||||
'soundfile',
|
||||
]
|
||||
|
||||
|
@ -1,10 +1,12 @@
|
||||
import os
|
||||
import io
|
||||
import wave
|
||||
import math
|
||||
import tempfile
|
||||
import collections
|
||||
import ctypes
|
||||
import io
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import pyogg
|
||||
import tempfile
|
||||
import wave
|
||||
|
||||
from .helpers import LimitingPool
|
||||
from collections import namedtuple
|
||||
@ -21,8 +23,9 @@ AUDIO_TYPE_NP = 'application/vnd.mozilla.np'
|
||||
AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm'
|
||||
AUDIO_TYPE_WAV = 'audio/wav'
|
||||
AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus'
|
||||
SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS]
|
||||
LOADABLE_AUDIO_EXTENSIONS = {'.wav': AUDIO_TYPE_WAV}
|
||||
AUDIO_TYPE_OGG_OPUS = 'application/vnd.deepspeech.ogg_opus'
|
||||
|
||||
SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, AUDIO_TYPE_OGG_OPUS]
|
||||
|
||||
OPUS_PCM_LEN_SIZE = 4
|
||||
OPUS_RATE_SIZE = 4
|
||||
@ -133,10 +136,11 @@ def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None,
|
||||
yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples))
|
||||
|
||||
|
||||
def get_audio_type_from_extension(ext):
|
||||
if ext in LOADABLE_AUDIO_EXTENSIONS:
|
||||
return LOADABLE_AUDIO_EXTENSIONS[ext]
|
||||
return None
|
||||
def get_loadable_audio_type_from_extension(ext):
|
||||
return {
|
||||
'.wav': AUDIO_TYPE_WAV,
|
||||
'.opus': AUDIO_TYPE_OGG_OPUS,
|
||||
}.get(ext, None)
|
||||
|
||||
|
||||
def read_audio_format_from_wav_file(wav_file):
|
||||
@ -340,6 +344,102 @@ def read_opus(opus_file):
|
||||
return audio_format, bytes(audio_data)
|
||||
|
||||
|
||||
def read_ogg_opus(ogg_file):
|
||||
error = ctypes.c_int()
|
||||
ogg_file_buffer = ogg_file.getbuffer()
|
||||
ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
|
||||
opusfile = pyogg.opus.op_open_memory(
|
||||
ubyte_array.from_buffer(ogg_file_buffer),
|
||||
len(ogg_file_buffer),
|
||||
ctypes.pointer(error)
|
||||
)
|
||||
|
||||
if error.value != 0:
|
||||
raise ValueError(
|
||||
("Ogg/Opus buffer could not be read."
|
||||
"Error code: {}").format(error.value)
|
||||
)
|
||||
|
||||
channel_count = pyogg.opus.op_channel_count(opusfile, -1)
|
||||
sample_rate = 48000 # opus files are always 48kHz
|
||||
sample_width = 2 # always 16-bit
|
||||
audio_format = AudioFormat(sample_rate, channel_count, sample_width)
|
||||
|
||||
# Allocate sufficient memory to store the entire PCM
|
||||
pcm_size = pyogg.opus.op_pcm_total(opusfile, -1)
|
||||
Buf = pyogg.opus.opus_int16*(pcm_size*channel_count)
|
||||
buf = Buf()
|
||||
|
||||
# Create a pointer to the newly allocated memory. It
|
||||
# seems we can only do pointer arithmetic on void
|
||||
# pointers. See
|
||||
# https://mattgwwalker.wordpress.com/2020/05/30/pointer-manipulation-in-python/
|
||||
buf_ptr = ctypes.cast(
|
||||
ctypes.pointer(buf),
|
||||
ctypes.c_void_p
|
||||
)
|
||||
assert buf_ptr.value is not None # for mypy
|
||||
buf_ptr_zero = buf_ptr.value
|
||||
|
||||
#: Bytes per sample
|
||||
bytes_per_sample = ctypes.sizeof(pyogg.opus.opus_int16)
|
||||
|
||||
# Read through the entire file, copying the PCM into the
|
||||
# buffer
|
||||
samples = 0
|
||||
while True:
|
||||
# Calculate remaining buffer size
|
||||
remaining_buffer = (
|
||||
len(buf) # int
|
||||
- (buf_ptr.value - buf_ptr_zero) // bytes_per_sample
|
||||
)
|
||||
|
||||
# Convert buffer pointer to the desired type
|
||||
ptr = ctypes.cast(
|
||||
buf_ptr,
|
||||
ctypes.POINTER(pyogg.opus.opus_int16)
|
||||
)
|
||||
|
||||
# Read the next section of PCM
|
||||
ns = pyogg.opus.op_read(
|
||||
opusfile,
|
||||
ptr,
|
||||
remaining_buffer,
|
||||
pyogg.ogg.c_int_p()
|
||||
)
|
||||
|
||||
# Check for errors
|
||||
if ns < 0:
|
||||
raise ValueError(
|
||||
"Error while reading OggOpus buffer. "+
|
||||
"Error code: {}".format(ns)
|
||||
)
|
||||
|
||||
# Increment the pointer
|
||||
buf_ptr.value += (
|
||||
ns
|
||||
* bytes_per_sample
|
||||
* channel_count
|
||||
)
|
||||
assert buf_ptr.value is not None # for mypy
|
||||
|
||||
samples += ns
|
||||
|
||||
# Check if we've finished
|
||||
if ns == 0:
|
||||
break
|
||||
|
||||
# Close the open file
|
||||
pyogg.opus.op_free(opusfile)
|
||||
|
||||
# Cast buffer to a one-dimensional array of chars
|
||||
#: Raw PCM data from audio file.
|
||||
CharBuffer = ctypes.c_byte * (bytes_per_sample * channel_count * pcm_size)
|
||||
audio_data = CharBuffer.from_buffer(buf)
|
||||
|
||||
return audio_format, audio_data
|
||||
|
||||
|
||||
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
|
||||
# wav_file is already a file-pointer here
|
||||
with wave.open(wav_file, 'wb') as wav_file_writer:
|
||||
@ -362,6 +462,8 @@ def read_audio(audio_type, audio_file):
|
||||
return read_wav(audio_file)
|
||||
if audio_type == AUDIO_TYPE_OPUS:
|
||||
return read_opus(audio_file)
|
||||
if audio_type == AUDIO_TYPE_OGG_OPUS:
|
||||
return read_ogg_opus(audio_file)
|
||||
raise ValueError('Unsupported audio type: {}'.format(audio_type))
|
||||
|
||||
|
||||
@ -384,11 +486,38 @@ def read_opus_duration(opus_file):
|
||||
return get_pcm_duration(pcm_buffer_size, audio_format)
|
||||
|
||||
|
||||
def read_ogg_opus_duration(ogg_file):
|
||||
error = ctypes.c_int()
|
||||
ogg_file_buffer = ogg_file.getbuffer()
|
||||
ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
|
||||
opusfile = pyogg.opus.op_open_memory(
|
||||
ubyte_array.from_buffer(ogg_file_buffer),
|
||||
len(ogg_file_buffer),
|
||||
ctypes.pointer(error)
|
||||
)
|
||||
|
||||
if error.value != 0:
|
||||
raise ValueError(
|
||||
("Ogg/Opus buffer could not be read."
|
||||
"Error code: {}").format(error.value)
|
||||
)
|
||||
|
||||
pcm_buffer_size = pyogg.opus.op_pcm_total(opusfile, -1)
|
||||
channel_count = pyogg.opus.op_channel_count(opusfile, -1)
|
||||
sample_rate = 48000 # opus files are always 48kHz
|
||||
sample_width = 2 # always 16-bit
|
||||
audio_format = AudioFormat(sample_rate, channel_count, sample_width)
|
||||
pyogg.opus.op_free(opusfile)
|
||||
return get_pcm_duration(pcm_buffer_size, audio_format)
|
||||
|
||||
|
||||
def read_duration(audio_type, audio_file):
|
||||
if audio_type == AUDIO_TYPE_WAV:
|
||||
return read_wav_duration(audio_file)
|
||||
if audio_type == AUDIO_TYPE_OPUS:
|
||||
return read_opus_duration(audio_file)
|
||||
if audio_type == AUDIO_TYPE_OGG_OPUS:
|
||||
return read_ogg_opus_duration(audio_file)
|
||||
raise ValueError('Unsupported audio type: {}'.format(audio_type))
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ from .audio import (
|
||||
AUDIO_TYPE_PCM,
|
||||
AUDIO_TYPE_OPUS,
|
||||
SERIALIZABLE_AUDIO_TYPES,
|
||||
get_audio_type_from_extension,
|
||||
get_loadable_audio_type_from_extension,
|
||||
write_wav
|
||||
)
|
||||
from .io import open_remote, is_remote_path
|
||||
@ -110,7 +110,7 @@ def load_sample(filename, label=None):
|
||||
util.audio.Sample instance if label is None, else util.sample_collections.LabeledSample instance
|
||||
"""
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
audio_type = get_audio_type_from_extension(ext)
|
||||
audio_type = get_loadable_audio_type_from_extension(ext)
|
||||
if audio_type is None:
|
||||
raise ValueError('Unknown audio type extension "{}"'.format(ext))
|
||||
return PackedSample(filename, audio_type, label)
|
||||
|
Loading…
Reference in New Issue
Block a user