Fix #3511: split-sets on sample size
This commit is contained in:
parent
385c8c769b
commit
6945663698
@ -2,6 +2,7 @@
|
||||
import codecs
|
||||
import fnmatch
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import unicodedata
|
||||
@ -236,14 +237,18 @@ def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
|
||||
|
||||
def _split_sets(filelist):
|
||||
# We initially split the entire set into 80% train and 20% test, then
|
||||
# split the train set into 80% train and 20% validation.
|
||||
train_beg = 0
|
||||
train_end = int(0.8 * len(filelist))
|
||||
"""
|
||||
randomply split the datasets into train, validation, and test sets where the size of the
|
||||
validation and test sets are determined by the `get_sample_size` function.
|
||||
"""
|
||||
random.shuffle(filelist)
|
||||
sample_size = get_sample_size(len(filelist))
|
||||
|
||||
dev_beg = int(0.8 * train_end)
|
||||
dev_end = train_end
|
||||
train_end = dev_beg
|
||||
train_beg = 0
|
||||
train_end = len(filelist) - 2 * sample_size
|
||||
|
||||
dev_beg = train_end
|
||||
dev_end = train_end + sample_size
|
||||
|
||||
test_beg = dev_end
|
||||
test_end = len(filelist)
|
||||
@ -255,5 +260,25 @@ def _split_sets(filelist):
|
||||
)
|
||||
|
||||
|
||||
def get_sample_size(population_size):
|
||||
"""calculates the sample size for a 99% confidence and 1% margin of error
|
||||
"""
|
||||
margin_of_error = 0.01
|
||||
fraction_picking = 0.50
|
||||
z_score = 2.58 # Corresponds to confidence level 99%
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2
|
||||
)
|
||||
sample_size = 0
|
||||
for train_size in range(population_size, 0, -1):
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2 * train_size
|
||||
)
|
||||
sample_size = int(numerator / denominator)
|
||||
if 2 * sample_size + train_size <= population_size:
|
||||
break
|
||||
return sample_size
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_download_and_preprocess_data(sys.argv[1])
|
||||
|
@ -5,6 +5,7 @@
|
||||
import codecs
|
||||
import fnmatch
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
@ -290,14 +291,18 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
||||
|
||||
|
||||
def _split_sets(filelist):
|
||||
# We initially split the entire set into 80% train and 20% test, then
|
||||
# split the train set into 80% train and 20% validation.
|
||||
train_beg = 0
|
||||
train_end = int(0.8 * len(filelist))
|
||||
"""
|
||||
randomply split the datasets into train, validation, and test sets where the size of the
|
||||
validation and test sets are determined by the `get_sample_size` function.
|
||||
"""
|
||||
random.shuffle(filelist)
|
||||
sample_size = get_sample_size(len(filelist))
|
||||
|
||||
dev_beg = int(0.8 * train_end)
|
||||
dev_end = train_end
|
||||
train_end = dev_beg
|
||||
train_beg = 0
|
||||
train_end = len(filelist) - 2 * sample_size
|
||||
|
||||
dev_beg = train_end
|
||||
dev_end = train_end + sample_size
|
||||
|
||||
test_beg = dev_end
|
||||
test_end = len(filelist)
|
||||
@ -309,6 +314,26 @@ def _split_sets(filelist):
|
||||
)
|
||||
|
||||
|
||||
def get_sample_size(population_size):
|
||||
"""calculates the sample size for a 99% confidence and 1% margin of error
|
||||
"""
|
||||
margin_of_error = 0.01
|
||||
fraction_picking = 0.50
|
||||
z_score = 2.58 # Corresponds to confidence level 99%
|
||||
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2
|
||||
)
|
||||
sample_size = 0
|
||||
for train_size in range(population_size, 0, -1):
|
||||
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||
margin_of_error ** 2 * train_size
|
||||
)
|
||||
sample_size = int(numerator / denominator)
|
||||
if 2 * sample_size + train_size <= population_size:
|
||||
break
|
||||
return sample_size
|
||||
|
||||
|
||||
def _read_data_set(
|
||||
filelist,
|
||||
thread_count,
|
||||
|
Loading…
Reference in New Issue
Block a user