Fix #3511: split-sets on sample size

Dustin Zubke 2021-02-28 16:09:37 -05:00
parent 385c8c769b
commit 6945663698
2 changed files with 64 additions and 14 deletions

File 1 of 2:

@@ -2,6 +2,7 @@
 import codecs
 import fnmatch
 import os
+import random
 import subprocess
 import sys
 import unicodedata
@@ -236,14 +237,18 @@ def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
 def _split_sets(filelist):
-    # We initially split the entire set into 80% train and 20% test, then
-    # split the train set into 80% train and 20% validation.
-    train_beg = 0
-    train_end = int(0.8 * len(filelist))
-
-    dev_beg = int(0.8 * train_end)
-    dev_end = train_end
-    train_end = dev_beg
+    """
+    randomly split the dataset into train, validation, and test sets, where the
+    sizes of the validation and test sets are determined by the `get_sample_size` function.
+    """
+    random.shuffle(filelist)
+    sample_size = get_sample_size(len(filelist))
+
+    train_beg = 0
+    train_end = len(filelist) - 2 * sample_size
+
+    dev_beg = train_end
+    dev_end = train_end + sample_size
 
     test_beg = dev_end
     test_end = len(filelist)
@@ -255,5 +260,25 @@ def _split_sets(filelist):
     )
 
+
+def get_sample_size(population_size):
+    """calculates the sample size for a 99% confidence and 1% margin of error
+    """
+    margin_of_error = 0.01
+    fraction_picking = 0.50
+    z_score = 2.58  # Corresponds to confidence level 99%
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+        margin_of_error ** 2
+    )
+    sample_size = 0
+    for train_size in range(population_size, 0, -1):
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+            margin_of_error ** 2 * train_size
+        )
+        sample_size = int(numerator / denominator)
+        if 2 * sample_size + train_size <= population_size:
+            break
+
+    return sample_size
+
 
 if __name__ == "__main__":
     _download_and_preprocess_data(sys.argv[1])
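For context on the new helper: `get_sample_size` is Cochran's sample-size formula, n0 = z^2 * p * (1 - p) / e^2, followed by the standard finite-population correction n = n0 / (1 + n0 / N), applied with the candidate train size as the population term. The loop walks the train size downward from the full population until the train split plus two samples (one dev, one test) fit inside the dataset. A minimal standalone sketch of that arithmetic, assuming an illustrative population of 10,000 files (the constants mirror the patch):

z_score = 2.58          # z-score for a 99% confidence level
margin_of_error = 0.01  # 1% margin of error
fraction_picking = 0.5  # worst-case population proportion

# Unadjusted sample size for an unbounded population: n0 = z^2 * p * (1 - p) / e^2
n0 = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / margin_of_error ** 2
print(n0)  # ~16641 with these constants

# Finite population correction against a population of size N: n = n0 / (1 + n0 / N)
N = 10_000  # illustrative population size (number of audio files)
print(int(n0 / (1 + n0 / N)))  # ~6246: too large to leave room for a train
                               # split, hence the patch's decrementing loop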

File 2 of 2:

@@ -5,6 +5,7 @@
 import codecs
 import fnmatch
 import os
+import random
 import subprocess
 import sys
 import tarfile
@@ -290,14 +291,18 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
 def _split_sets(filelist):
-    # We initially split the entire set into 80% train and 20% test, then
-    # split the train set into 80% train and 20% validation.
-    train_beg = 0
-    train_end = int(0.8 * len(filelist))
-
-    dev_beg = int(0.8 * train_end)
-    dev_end = train_end
-    train_end = dev_beg
+    """
+    randomly split the dataset into train, validation, and test sets, where the
+    sizes of the validation and test sets are determined by the `get_sample_size` function.
+    """
+    random.shuffle(filelist)
+    sample_size = get_sample_size(len(filelist))
+
+    train_beg = 0
+    train_end = len(filelist) - 2 * sample_size
+
+    dev_beg = train_end
+    dev_end = train_end + sample_size
 
     test_beg = dev_end
     test_end = len(filelist)
@@ -309,6 +314,26 @@ def _split_sets(filelist):
     )
 
+
+def get_sample_size(population_size):
+    """calculates the sample size for a 99% confidence and 1% margin of error
+    """
+    margin_of_error = 0.01
+    fraction_picking = 0.50
+    z_score = 2.58  # Corresponds to confidence level 99%
+    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+        margin_of_error ** 2
+    )
+    sample_size = 0
+    for train_size in range(population_size, 0, -1):
+        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
+            margin_of_error ** 2 * train_size
+        )
+        sample_size = int(numerator / denominator)
+        if 2 * sample_size + train_size <= population_size:
+            break
+
+    return sample_size
+
 
 def _read_data_set(
     filelist,
     thread_count,
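Putting the two pieces together: after the patch, `_split_sets` shuffles the file list and carves it into contiguous train, dev, and test slices, with dev and test each holding `sample_size` entries. A hedged usage sketch of that slicing follows; the toy `filelist` and the standalone copy of `get_sample_size` are illustrative only, since the real importers operate on rows describing wav files:

import random

def get_sample_size(population_size):
    # Copied from the patch above: 99% confidence, 1% margin of error.
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58
    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
        margin_of_error ** 2
    )
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
            margin_of_error ** 2 * train_size
        )
        sample_size = int(numerator / denominator)
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size

# Toy stand-in for the importer's file list.
filelist = [f"sample_{i:05d}.wav" for i in range(10_000)]

random.shuffle(filelist)
sample_size = get_sample_size(len(filelist))

# Same index arithmetic as the patched _split_sets.
train_files = filelist[: len(filelist) - 2 * sample_size]
dev_files = filelist[len(filelist) - 2 * sample_size : len(filelist) - sample_size]
test_files = filelist[len(filelist) - sample_size :]

assert len(dev_files) == len(test_files) == sample_size
assert len(train_files) + 2 * sample_size == len(filelist)

With these constants, a 10,000-file corpus ends up with dev and test sets of roughly 3,100 files each, while for very large corpora the required sample size plateaus near the ~16,641-file ceiling implied by the formula, so dev and test shrink as a fraction of the data.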