Work remote I/O into audio utils -- a bit more involved

This commit is contained in:
CatalinVoss 2020-11-12 14:17:03 -08:00
parent 3d503bd69e
commit ad08830421

View File

@ -8,6 +8,7 @@ import numpy as np
from .helpers import LimitingPool
from collections import namedtuple
from .io import open_remote, remove_remote, copy_remote, is_remote_path
AudioFormat = namedtuple('AudioFormat', 'rate channels width')
@ -168,29 +169,44 @@ class AudioFile:
self.audio_format = audio_format
self.as_path = as_path
self.open_file = None
self.open_wav = None
self.tmp_file_path = None
def __enter__(self):
    """Open the audio in the requested target format.

    If ``self.audio_path`` ends in ``.wav`` and already has
    ``self.audio_format``, it is used directly. Otherwise the audio is
    converted into a local temporary WAV file (after copying a remote
    source to a local temporary file first).

    Returns:
        A path (the original path or the converted temporary file) when
        ``self.as_path`` is set, otherwise an opened ``wave`` reader.
    """
    # Reset here so __exit__ can test it even when no remote copy is made.
    self.tmp_src_file_path = None

    if self.audio_path.endswith('.wav'):
        # The wave module parses raw bytes, so the stream must be binary.
        # NOTE(review): assumes open_remote accepts the same mode strings as open() - confirm.
        self.open_file = open_remote(self.audio_path, 'rb')
        try:
            self.open_wav = wave.open(self.open_file)
        except Exception:
            # Do not leak the underlying stream if the header cannot be parsed.
            self.open_file.close()
            self.open_file = None
            raise
        if read_audio_format_from_wav_file(self.open_wav) == self.audio_format:
            if self.as_path:
                # Caller only wants the path; the handles were just for probing.
                self.open_wav.close()
                self.open_file.close()
                return self.audio_path
            return self.open_wav
        # Wrong format: release the probe handles before converting.
        self.open_wav.close()
        self.open_wav = None
        self.open_file.close()
        self.open_file = None

    # If the format isn't right, copy the file to local tmp dir and do the conversion on disk
    if is_remote_path(self.audio_path):
        src_fd, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav')
        os.close(src_fd)  # mkstemp hands back an open descriptor; only the path is needed
        # NOTE(review): mkstemp has already created the destination file -
        # confirm copy_remote overwrites an existing target.
        copy_remote(self.audio_path, self.tmp_src_file_path)
        # Convert from the local copy, not from the (still unset) output path.
        self.audio_path = self.tmp_src_file_path

    dst_fd, self.tmp_file_path = tempfile.mkstemp(suffix='.wav')
    os.close(dst_fd)
    convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format)
    if self.as_path:
        return self.tmp_file_path
    self.open_wav = wave.open(self.tmp_file_path, 'r')
    return self.open_wav
def __exit__(self, *args):
    """Close any handles opened by ``__enter__`` and delete its temporary files.

    Args:
        *args: Exception details supplied by the ``with`` statement (unused).
    """
    if not self.as_path:
        # In path mode __enter__ has already closed whatever it opened.
        if self.open_wav is not None:
            self.open_wav.close()
        # open_file is only set when the source was opened as a .wav stream.
        if self.open_file:
            self.open_file.close()

    # tmp_src_file_path exists only when a remote source was copied locally,
    # so look both attributes up defensively instead of assuming they are set.
    for attr_name in ('tmp_file_path', 'tmp_src_file_path'):
        tmp_path = getattr(self, attr_name, None)
        if tmp_path is not None:
            os.remove(tmp_path)
            # Forget the path so a second exit does not try to remove it again.
            setattr(self, attr_name, None)
def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
@ -320,7 +336,7 @@ def read_opus(opus_file):
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
with wave.open(wav_file, 'wb') as wav_file_writer:
with wave.open_remote(wav_file, 'wb') as wav_file_writer:
wav_file_writer.setframerate(audio_format.rate)
wav_file_writer.setnchannels(audio_format.channels)
wav_file_writer.setsampwidth(audio_format.width)
@ -329,7 +345,7 @@ def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
def read_wav(wav_file):
    """Read the audio format and all PCM frames from an open WAV stream.

    Args:
        wav_file: Seekable binary file object containing WAV data. It is
            rewound to the start before reading.

    Returns:
        Tuple ``(audio_format, pcm_data)`` where ``audio_format`` is the
        result of ``read_audio_format_from_wav_file`` and ``pcm_data`` is the
        bytes of every frame in the file.
    """
    wav_file.seek(0)
    # wav_file is already an open file object, so plain wave.open is the right
    # call; the wave module has no open_remote function.
    with wave.open(wav_file, 'rb') as wav_file_reader:
        audio_format = read_audio_format_from_wav_file(wav_file_reader)
        pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
        return audio_format, pcm_data
@ -353,7 +369,7 @@ def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, b
def read_wav_duration(wav_file):
    """Return the duration in seconds of an open WAV stream.

    Args:
        wav_file: Seekable binary file object containing WAV data. It is
            rewound to the start before reading.

    Returns:
        Number of frames divided by the frame rate, as a float.
    """
    wav_file.seek(0)
    # wav_file is already an open file object, so plain wave.open is the right
    # call; the wave module has no open_remote function.
    with wave.open(wav_file, 'rb') as wav_file_reader:
        return wav_file_reader.getnframes() / wav_file_reader.getframerate()