Merge pull request #3546 from dzubke/Iss-3511_split-sets
Fix #3511: split-sets on sample size
This commit is contained in:
commit
8c8b80dc0b
@ -2,6 +2,7 @@
|
|||||||
import codecs
|
import codecs
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import unicodedata
|
import unicodedata
|
||||||
@ -236,14 +237,18 @@ def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file):
|
|||||||
|
|
||||||
|
|
||||||
def _split_sets(filelist):
|
def _split_sets(filelist):
|
||||||
# We initially split the entire set into 80% train and 20% test, then
|
"""
|
||||||
# split the train set into 80% train and 20% validation.
|
randomply split the datasets into train, validation, and test sets where the size of the
|
||||||
train_beg = 0
|
validation and test sets are determined by the `get_sample_size` function.
|
||||||
train_end = int(0.8 * len(filelist))
|
"""
|
||||||
|
random.shuffle(filelist)
|
||||||
|
sample_size = get_sample_size(len(filelist))
|
||||||
|
|
||||||
dev_beg = int(0.8 * train_end)
|
train_beg = 0
|
||||||
dev_end = train_end
|
train_end = len(filelist) - 2 * sample_size
|
||||||
train_end = dev_beg
|
|
||||||
|
dev_beg = train_end
|
||||||
|
dev_end = train_end + sample_size
|
||||||
|
|
||||||
test_beg = dev_end
|
test_beg = dev_end
|
||||||
test_end = len(filelist)
|
test_end = len(filelist)
|
||||||
@ -255,5 +260,25 @@ def _split_sets(filelist):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sample_size(population_size):
|
||||||
|
"""calculates the sample size for a 99% confidence and 1% margin of error
|
||||||
|
"""
|
||||||
|
margin_of_error = 0.01
|
||||||
|
fraction_picking = 0.50
|
||||||
|
z_score = 2.58 # Corresponds to confidence level 99%
|
||||||
|
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||||
|
margin_of_error ** 2
|
||||||
|
)
|
||||||
|
sample_size = 0
|
||||||
|
for train_size in range(population_size, 0, -1):
|
||||||
|
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||||
|
margin_of_error ** 2 * train_size
|
||||||
|
)
|
||||||
|
sample_size = int(numerator / denominator)
|
||||||
|
if 2 * sample_size + train_size <= population_size:
|
||||||
|
break
|
||||||
|
return sample_size
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_download_and_preprocess_data(sys.argv[1])
|
_download_and_preprocess_data(sys.argv[1])
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
import codecs
|
import codecs
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
@ -290,14 +291,18 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file):
|
|||||||
|
|
||||||
|
|
||||||
def _split_sets(filelist):
|
def _split_sets(filelist):
|
||||||
# We initially split the entire set into 80% train and 20% test, then
|
"""
|
||||||
# split the train set into 80% train and 20% validation.
|
randomply split the datasets into train, validation, and test sets where the size of the
|
||||||
train_beg = 0
|
validation and test sets are determined by the `get_sample_size` function.
|
||||||
train_end = int(0.8 * len(filelist))
|
"""
|
||||||
|
random.shuffle(filelist)
|
||||||
|
sample_size = get_sample_size(len(filelist))
|
||||||
|
|
||||||
dev_beg = int(0.8 * train_end)
|
train_beg = 0
|
||||||
dev_end = train_end
|
train_end = len(filelist) - 2 * sample_size
|
||||||
train_end = dev_beg
|
|
||||||
|
dev_beg = train_end
|
||||||
|
dev_end = train_end + sample_size
|
||||||
|
|
||||||
test_beg = dev_end
|
test_beg = dev_end
|
||||||
test_end = len(filelist)
|
test_end = len(filelist)
|
||||||
@ -309,6 +314,26 @@ def _split_sets(filelist):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sample_size(population_size):
|
||||||
|
"""calculates the sample size for a 99% confidence and 1% margin of error
|
||||||
|
"""
|
||||||
|
margin_of_error = 0.01
|
||||||
|
fraction_picking = 0.50
|
||||||
|
z_score = 2.58 # Corresponds to confidence level 99%
|
||||||
|
numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||||
|
margin_of_error ** 2
|
||||||
|
)
|
||||||
|
sample_size = 0
|
||||||
|
for train_size in range(population_size, 0, -1):
|
||||||
|
denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
|
||||||
|
margin_of_error ** 2 * train_size
|
||||||
|
)
|
||||||
|
sample_size = int(numerator / denominator)
|
||||||
|
if 2 * sample_size + train_size <= population_size:
|
||||||
|
break
|
||||||
|
return sample_size
|
||||||
|
|
||||||
|
|
||||||
def _read_data_set(
|
def _read_data_set(
|
||||||
filelist,
|
filelist,
|
||||||
thread_count,
|
thread_count,
|
||||||
|
Loading…
Reference in New Issue
Block a user