Read audio format from data before running augmentation passes instead of assuming default

This commit is contained in:
Reuben Morais 2021-01-18 12:11:03 +00:00
parent 8c0d46cb7f
commit 79a42b345d
2 changed files with 48 additions and 2 deletions

View File

@ -76,6 +76,8 @@ class Sample:
if audio_type in SERIALIZABLE_AUDIO_TYPES: if audio_type in SERIALIZABLE_AUDIO_TYPES:
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
self.duration = read_duration(audio_type, self.audio) self.duration = read_duration(audio_type, self.audio)
if not self.audio_format:
self.audio_format = read_format(audio_type, self.audio)
else: else:
self.audio = raw_data self.audio = raw_data
if self.audio_format is None: if self.audio_format is None:
@ -521,6 +523,51 @@ def read_duration(audio_type, audio_file):
raise ValueError('Unsupported audio type: {}'.format(audio_type)) raise ValueError('Unsupported audio type: {}'.format(audio_type))
def read_wav_format(wav_file):
wav_file.seek(0)
with wave.open(wav_file, 'rb') as wav_file_reader:
return read_audio_format_from_wav_file(wav_file_reader)
def read_opus_format(opus_file):
_, audio_format = read_opus_header(opus_file)
return audio_format
def read_ogg_opus_format(ogg_file):
error = ctypes.c_int()
ogg_file_buffer = ogg_file.getbuffer()
ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
opusfile = pyogg.opus.op_open_memory(
ubyte_array.from_buffer(ogg_file_buffer),
len(ogg_file_buffer),
ctypes.pointer(error)
)
if error.value != 0:
raise ValueError(
("Ogg/Opus buffer could not be read."
"Error code: {}").format(error.value)
)
channel_count = pyogg.opus.op_channel_count(opusfile, -1)
pyogg.opus.op_free(opusfile)
sample_rate = 48000 # opus files are always 48kHz
sample_width = 2 # always 16-bit
return AudioFormat(sample_rate, channel_count, sample_width)
def read_format(audio_type, audio_file):
if audio_type == AUDIO_TYPE_WAV:
return read_wav_format(audio_file)
if audio_type == AUDIO_TYPE_OPUS:
return read_opus_format(audio_file)
if audio_type == AUDIO_TYPE_OGG_OPUS:
return read_ogg_opus_format(audio_file)
raise ValueError('Unsupported audio type: {}'.format(audio_type))
def get_dtype(audio_format): def get_dtype(audio_format):
if audio_format.width not in [1, 2, 4]: if audio_format.width not in [1, 2, 4]:
raise ValueError('Unsupported sample width: {}'.format(audio_format.width)) raise ValueError('Unsupported sample width: {}'.format(audio_format.width))

View File

@ -11,7 +11,6 @@ from functools import partial
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
from .audio import ( from .audio import (
Sample, Sample,
DEFAULT_FORMAT,
AUDIO_TYPE_PCM, AUDIO_TYPE_PCM,
AUDIO_TYPE_OPUS, AUDIO_TYPE_OPUS,
SERIALIZABLE_AUDIO_TYPES, SERIALIZABLE_AUDIO_TYPES,
@ -40,7 +39,7 @@ CONTENT_TYPE_TRANSCRIPT = 'transcript'
class LabeledSample(Sample): class LabeledSample(Sample):
"""In-memory labeled audio sample representing an utterance. """In-memory labeled audio sample representing an utterance.
Derived from util.audio.Sample and used by sample collection readers and writers.""" Derived from util.audio.Sample and used by sample collection readers and writers."""
def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None): def __init__(self, audio_type, raw_data, transcript, audio_format=None, sample_id=None):
""" """
Parameters Parameters
---------- ----------