Read audio format from data before running augmentation passes instead of assuming default
This commit is contained in:
parent
8c0d46cb7f
commit
79a42b345d
|
@ -76,6 +76,8 @@ class Sample:
|
|||
if audio_type in SERIALIZABLE_AUDIO_TYPES:
|
||||
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
|
||||
self.duration = read_duration(audio_type, self.audio)
|
||||
if not self.audio_format:
|
||||
self.audio_format = read_format(audio_type, self.audio)
|
||||
else:
|
||||
self.audio = raw_data
|
||||
if self.audio_format is None:
|
||||
|
@ -521,6 +523,51 @@ def read_duration(audio_type, audio_file):
|
|||
raise ValueError('Unsupported audio type: {}'.format(audio_type))
|
||||
|
||||
|
||||
def read_wav_format(wav_file):
|
||||
wav_file.seek(0)
|
||||
with wave.open(wav_file, 'rb') as wav_file_reader:
|
||||
return read_audio_format_from_wav_file(wav_file_reader)
|
||||
|
||||
|
||||
def read_opus_format(opus_file):
|
||||
_, audio_format = read_opus_header(opus_file)
|
||||
return audio_format
|
||||
|
||||
|
||||
def read_ogg_opus_format(ogg_file):
|
||||
error = ctypes.c_int()
|
||||
ogg_file_buffer = ogg_file.getbuffer()
|
||||
ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
|
||||
opusfile = pyogg.opus.op_open_memory(
|
||||
ubyte_array.from_buffer(ogg_file_buffer),
|
||||
len(ogg_file_buffer),
|
||||
ctypes.pointer(error)
|
||||
)
|
||||
|
||||
if error.value != 0:
|
||||
raise ValueError(
|
||||
("Ogg/Opus buffer could not be read."
|
||||
"Error code: {}").format(error.value)
|
||||
)
|
||||
|
||||
channel_count = pyogg.opus.op_channel_count(opusfile, -1)
|
||||
pyogg.opus.op_free(opusfile)
|
||||
|
||||
sample_rate = 48000 # opus files are always 48kHz
|
||||
sample_width = 2 # always 16-bit
|
||||
return AudioFormat(sample_rate, channel_count, sample_width)
|
||||
|
||||
|
||||
def read_format(audio_type, audio_file):
|
||||
if audio_type == AUDIO_TYPE_WAV:
|
||||
return read_wav_format(audio_file)
|
||||
if audio_type == AUDIO_TYPE_OPUS:
|
||||
return read_opus_format(audio_file)
|
||||
if audio_type == AUDIO_TYPE_OGG_OPUS:
|
||||
return read_ogg_opus_format(audio_file)
|
||||
raise ValueError('Unsupported audio type: {}'.format(audio_type))
|
||||
|
||||
|
||||
def get_dtype(audio_format):
|
||||
if audio_format.width not in [1, 2, 4]:
|
||||
raise ValueError('Unsupported sample width: {}'.format(audio_format.width))
|
||||
|
|
|
@ -11,7 +11,6 @@ from functools import partial
|
|||
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
|
||||
from .audio import (
|
||||
Sample,
|
||||
DEFAULT_FORMAT,
|
||||
AUDIO_TYPE_PCM,
|
||||
AUDIO_TYPE_OPUS,
|
||||
SERIALIZABLE_AUDIO_TYPES,
|
||||
|
@ -40,7 +39,7 @@ CONTENT_TYPE_TRANSCRIPT = 'transcript'
|
|||
class LabeledSample(Sample):
|
||||
"""In-memory labeled audio sample representing an utterance.
|
||||
Derived from util.audio.Sample and used by sample collection readers and writers."""
|
||||
def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None):
|
||||
def __init__(self, audio_type, raw_data, transcript, audio_format=None, sample_id=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
|
Loading…
Reference in New Issue