Read audio format from data before running augmentation passes instead of assuming default
This commit is contained in:
		
							parent
							
								
									8c0d46cb7f
								
							
						
					
					
						commit
						79a42b345d
					
				| @ -76,6 +76,8 @@ class Sample: | |||||||
|         if audio_type in SERIALIZABLE_AUDIO_TYPES: |         if audio_type in SERIALIZABLE_AUDIO_TYPES: | ||||||
|             self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) |             self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) | ||||||
|             self.duration = read_duration(audio_type, self.audio) |             self.duration = read_duration(audio_type, self.audio) | ||||||
|  |             if not self.audio_format: | ||||||
|  |                 self.audio_format = read_format(audio_type, self.audio) | ||||||
|         else: |         else: | ||||||
|             self.audio = raw_data |             self.audio = raw_data | ||||||
|             if self.audio_format is None: |             if self.audio_format is None: | ||||||
| @ -521,6 +523,51 @@ def read_duration(audio_type, audio_file): | |||||||
|     raise ValueError('Unsupported audio type: {}'.format(audio_type)) |     raise ValueError('Unsupported audio type: {}'.format(audio_type)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
def read_wav_format(wav_file):
    """Return the audio format described by a WAV stream.

    Parameters
    ----------
    wav_file
        Seekable binary file-like object holding WAV data.

    Returns
    -------
    Whatever ``read_audio_format_from_wav_file`` derives from the opened
    ``wave`` reader (presumably an ``AudioFormat`` - confirm against that
    helper).
    """
    # Rewind so the header is parsed from the start, regardless of where a
    # previous consumer of the stream left the position.
    wav_file.seek(0)
    with wave.open(wav_file, 'rb') as reader:
        audio_format = read_audio_format_from_wav_file(reader)
    return audio_format
|  | 
 | ||||||
|  | 
 | ||||||
def read_opus_format(opus_file):
    """Return the audio format stored in the header of an Opus sample.

    Parameters
    ----------
    opus_file
        File-like object passed unchanged to ``read_opus_header``.

    Returns
    -------
    The second element of the pair returned by ``read_opus_header``.
    """
    # NOTE(review): no explicit rewind happens here - presumably
    # read_opus_header positions the stream itself; confirm.
    _ignored, header_format = read_opus_header(opus_file)
    return header_format
|  | 
 | ||||||
|  | 
 | ||||||
def read_ogg_opus_format(ogg_file):
    """Determine the audio format of an in-memory Ogg/Opus stream.

    Parameters
    ----------
    ogg_file
        In-memory buffer (e.g. ``io.BytesIO``) holding the Ogg/Opus data;
        it has to provide ``getbuffer()``.

    Returns
    -------
    AudioFormat
        Built from a fixed rate of 48000 Hz, the channel count reported by
        ``op_channel_count`` and a fixed sample width of 2 bytes.

    Raises
    ------
    ValueError
        If ``op_open_memory`` reports a non-zero error code.
    """
    error = ctypes.c_int()
    ogg_file_buffer = ogg_file.getbuffer()
    # ctypes array type sized to the buffer, so the data is handed to
    # opusfile without being copied.
    ubyte_array = ctypes.c_ubyte * len(ogg_file_buffer)
    opusfile = pyogg.opus.op_open_memory(
        ubyte_array.from_buffer(ogg_file_buffer),
        len(ogg_file_buffer),
        ctypes.pointer(error)
    )

    if error.value != 0:
        # The two literals are concatenated, so the first one needs a
        # trailing space to keep the sentences apart.
        raise ValueError(
            ("Ogg/Opus buffer could not be read. "
             "Error code: {}").format(error.value)
        )

    try:
        # -1 queries the current link of the stream.
        channel_count = pyogg.opus.op_channel_count(opusfile, -1)
    finally:
        # Always release the native handle, even if the query fails.
        pyogg.opus.op_free(opusfile)

    sample_rate = 48000  # opus files are always 48kHz
    sample_width = 2  # always 16-bit
    return AudioFormat(sample_rate, channel_count, sample_width)
|  | 
 | ||||||
|  | 
 | ||||||
def read_format(audio_type, audio_file):
    """Read the audio format of a serialized sample of the given type.

    Parameters
    ----------
    audio_type
        One of ``AUDIO_TYPE_WAV``, ``AUDIO_TYPE_OPUS`` or
        ``AUDIO_TYPE_OGG_OPUS``.
    audio_file
        File-like object handed to the type-specific reader.

    Returns
    -------
    The result of the matching ``read_*_format`` function.

    Raises
    ------
    ValueError
        If ``audio_type`` is none of the supported types.
    """
    # Checked in order with ``==``, mirroring a plain if-chain.
    format_readers = (
        (AUDIO_TYPE_WAV, read_wav_format),
        (AUDIO_TYPE_OPUS, read_opus_format),
        (AUDIO_TYPE_OGG_OPUS, read_ogg_opus_format),
    )
    for supported_type, reader in format_readers:
        if audio_type == supported_type:
            return reader(audio_file)
    raise ValueError('Unsupported audio type: {}'.format(audio_type))
|  | 
 | ||||||
|  | 
 | ||||||
| def get_dtype(audio_format): | def get_dtype(audio_format): | ||||||
|     if audio_format.width not in [1, 2, 4]: |     if audio_format.width not in [1, 2, 4]: | ||||||
|         raise ValueError('Unsupported sample width: {}'.format(audio_format.width)) |         raise ValueError('Unsupported sample width: {}'.format(audio_format.width)) | ||||||
|  | |||||||
| @ -11,7 +11,6 @@ from functools import partial | |||||||
| from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap | from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap | ||||||
| from .audio import ( | from .audio import ( | ||||||
|     Sample, |     Sample, | ||||||
|     DEFAULT_FORMAT, |  | ||||||
|     AUDIO_TYPE_PCM, |     AUDIO_TYPE_PCM, | ||||||
|     AUDIO_TYPE_OPUS, |     AUDIO_TYPE_OPUS, | ||||||
|     SERIALIZABLE_AUDIO_TYPES, |     SERIALIZABLE_AUDIO_TYPES, | ||||||
| @ -40,7 +39,7 @@ CONTENT_TYPE_TRANSCRIPT = 'transcript' | |||||||
| class LabeledSample(Sample): | class LabeledSample(Sample): | ||||||
|     """In-memory labeled audio sample representing an utterance. |     """In-memory labeled audio sample representing an utterance. | ||||||
|     Derived from util.audio.Sample and used by sample collection readers and writers.""" |     Derived from util.audio.Sample and used by sample collection readers and writers.""" | ||||||
|     def __init__(self, audio_type, raw_data, transcript, audio_format=DEFAULT_FORMAT, sample_id=None): |     def __init__(self, audio_type, raw_data, transcript, audio_format=None, sample_id=None): | ||||||
|         """ |         """ | ||||||
|         Parameters |         Parameters | ||||||
|         ---------- |         ---------- | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user