From 31b92801943485bfa5a4065bc6a903f27f241a69 Mon Sep 17 00:00:00 2001 From: Josh Meyer Date: Mon, 16 Aug 2021 06:22:20 -0400 Subject: [PATCH] Allow reading of audio files via str path --- training/coqui_stt_training/util/audio.py | 28 +++++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/training/coqui_stt_training/util/audio.py b/training/coqui_stt_training/util/audio.py index ab96c606..2f2a6868 100644 --- a/training/coqui_stt_training/util/audio.py +++ b/training/coqui_stt_training/util/audio.py @@ -500,7 +500,8 @@ def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): def read_wav(wav_file): - wav_file.seek(0) + if not type(wav_file) is str: + wav_file.seek(0) with wave.open(wav_file, "rb") as wav_file_reader: audio_format = read_audio_format_from_wav_file(wav_file_reader) pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) @@ -530,7 +531,8 @@ def write_audio( def read_wav_duration(wav_file): - wav_file.seek(0) + if not type(wav_file) is str: + wav_file.seek(0) with wave.open(wav_file, "rb") as wav_file_reader: return wav_file_reader.getnframes() / wav_file_reader.getframerate() @@ -542,13 +544,18 @@ def read_opus_duration(opus_file): def read_ogg_opus_duration(ogg_file): error = ctypes.c_int() - ogg_file_buffer = ogg_file.getbuffer() - ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer) - opusfile = pyogg.opus.op_open_memory( - ubyte_array.from_buffer(ogg_file_buffer), - len(ogg_file_buffer), - ctypes.pointer(error), - ) + if type(ogg_file) is str: + opusfile = pyogg.opus.op_open_file( + bytes(ogg_file, encoding="utf-8"), ctypes.pointer(error) + ) + else: + ogg_file_buffer = ogg_file.getbuffer() + ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer) + opusfile = pyogg.opus.op_open_memory( + ubyte_array.from_buffer(ogg_file_buffer), + len(ogg_file_buffer), + ctypes.pointer(error), + ) if error.value != 0: raise ValueError( @@ -575,7 +582,8 @@ def read_duration(audio_type, audio_file): def read_wav_format(wav_file): - wav_file.seek(0) + if not type(wav_file) is str: + wav_file.seek(0) with wave.open(wav_file, "rb") as wav_file_reader: return read_audio_format_from_wav_file(wav_file_reader)