Merge pull request #3546 from dzubke/Iss-3511_split-sets

Fix #3511: split-sets on sample size
Reuben Morais, 2021-03-01 18:09:38 +00:00 (committed by GitHub)
commit 8c8b80dc0b
2 changed files with 64 additions and 14 deletions


@@ -2,6 +2,7 @@
 import codecs
 import fnmatch
 import os
+import random
 import subprocess
 import sys
 import unicodedata
@@ -236,14 +237,18 @@ def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):


 def _split_sets(filelist):
-    # We initially split the entire set into 80% train and 20% test, then
-    # split the train set into 80% train and 20% validation.
-    train_beg = 0
-    train_end = int(0.8 * len(filelist))
+    """
+    Randomly split the dataset into train, validation, and test sets, where the sizes of
+    the validation and test sets are determined by the `get_sample_size` function.
+    """
+    random.shuffle(filelist)
+    sample_size = get_sample_size(len(filelist))

-    dev_beg = int(0.8 * train_end)
-    dev_end = train_end
-    train_end = dev_beg
+    train_beg = 0
+    train_end = len(filelist) - 2 * sample_size
+
+    dev_beg = train_end
+    dev_end = train_end + sample_size

     test_beg = dev_end
     test_end = len(filelist)
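With this change, a shuffled filelist of n entries is split into train = [0, n - 2*sample_size), dev = [n - 2*sample_size, n - sample_size), and test = [n - sample_size, n). A minimal standalone sketch of that index arithmetic, with illustrative values for n and sample_size (in the patch, sample_size comes from get_sample_size):

    n = 10000          # illustrative population size, not from the patch
    sample_size = 500  # illustrative; the patch computes this via get_sample_size

    train = (0, n - 2 * sample_size)          # indices [0, 9000)
    dev = (train[1], train[1] + sample_size)  # indices [9000, 9500)
    test = (dev[1], n)                        # indices [9500, 10000)

    # The three ranges are contiguous, disjoint, and cover the whole filelist.
    assert train[1] == dev[0] and dev[1] == test[0] and test[1] == n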
@@ -255,5 +260,25 @@ def _split_sets(filelist):
     )


+def get_sample_size(population_size):
+    """Calculate the sample size for a 99% confidence level and a 1% margin of error.
+    """
+    margin_of_error = 0.01
+    fraction_picking = 0.50
+    z_score = 2.58  # Corresponds to a 99% confidence level
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+        margin_of_error ** 2
+    )
+    sample_size = 0
+    for train_size in range(population_size, 0, -1):
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+            margin_of_error ** 2 * train_size
+        )
+        sample_size = int(numerator / denominator)
+        if 2 * sample_size + train_size <= population_size:
+            break
+    return sample_size
+
+
 if __name__ == "__main__":
     _download_and_preprocess_data(sys.argv[1])
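The get_sample_size loop above is Cochran's sample-size formula with a finite population correction: n0 = z^2 * p * (1 - p) / e^2 is the required sample for an unbounded population, and n = n0 / (1 + n0 / N) corrects it for a finite train set of size N; the loop then shrinks the candidate train size until the train set plus two samples fit inside the population. A condensed standalone sketch of the same computation (population_size = 50000 is an illustrative value, not from the patch):

    # Cochran's formula for z = 2.58 (99% confidence), p = 0.5, e = 0.01.
    z_score, fraction_picking, margin_of_error = 2.58, 0.50, 0.01
    n0 = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / margin_of_error ** 2
    # n0 == 16641.0 for these constants

    population_size = 50000  # illustrative
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        # Finite population correction for the current candidate train size.
        sample_size = int(n0 / (1 + n0 / train_size))
        if 2 * sample_size + train_size <= population_size:
            break

    assert 2 * sample_size + train_size <= population_size
    print(train_size, sample_size)

Stepping train_size down one entry at a time makes the search O(population_size) in the worst case, but it runs once per import, so the linear scan is a reasonable trade for simplicity.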


@@ -5,6 +5,7 @@
 import codecs
 import fnmatch
 import os
+import random
 import subprocess
 import sys
 import tarfile
@@ -290,14 +291,18 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):


 def _split_sets(filelist):
-    # We initially split the entire set into 80% train and 20% test, then
-    # split the train set into 80% train and 20% validation.
-    train_beg = 0
-    train_end = int(0.8 * len(filelist))
+    """
+    Randomly split the dataset into train, validation, and test sets, where the sizes of
+    the validation and test sets are determined by the `get_sample_size` function.
+    """
+    random.shuffle(filelist)
+    sample_size = get_sample_size(len(filelist))

-    dev_beg = int(0.8 * train_end)
-    dev_end = train_end
-    train_end = dev_beg
+    train_beg = 0
+    train_end = len(filelist) - 2 * sample_size
+
+    dev_beg = train_end
+    dev_end = train_end + sample_size

     test_beg = dev_end
     test_end = len(filelist)
@@ -309,6 +314,26 @@ def _split_sets(filelist):
     )


+def get_sample_size(population_size):
+    """Calculate the sample size for a 99% confidence level and a 1% margin of error.
+    """
+    margin_of_error = 0.01
+    fraction_picking = 0.50
+    z_score = 2.58  # Corresponds to a 99% confidence level
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+        margin_of_error ** 2
+    )
+    sample_size = 0
+    for train_size in range(population_size, 0, -1):
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+            margin_of_error ** 2 * train_size
+        )
+        sample_size = int(numerator / denominator)
+        if 2 * sample_size + train_size <= population_size:
+            break
+    return sample_size
+
+
 def _read_data_set(
     filelist,
     thread_count,