Currently working notebook
This commit is contained in:
parent
a37ca2ec27
commit
59e32556a4
191
train-ldc.ipynb
191
train-ldc.ipynb
@ -1,3 +1,5 @@
|
|||||||
|
# Download LDC data
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -30,17 +32,190 @@ def download_and_preprocess_data(data_dir):
|
|||||||
|
|
||||||
download_and_preprocess_data('/home/STT/data')
|
download_and_preprocess_data('/home/STT/data')
|
||||||
|
|
||||||
|
# Train
|
||||||
|
|
||||||
from STT.training.coqui_stt_training.train import train, early_training_checks
|
from STT.training.coqui_stt_training.util.config import _SttConfig, _ConfigSingleton
|
||||||
from STT.training.coqui_stt_training.util.config import initialize_globals
|
from STT.training.coqui_stt_training.util.augmentations import parse_augmentations, NormalizeSampleRate
|
||||||
|
from STT.training.coqui_stt_training.util.helpers import parse_file_size
|
||||||
|
from STT.training.coqui_stt_training.util.gpu import get_available_gpus
|
||||||
|
from coqui_stt_ctcdecoder import Alphabet
|
||||||
|
from xdg import BaseDirectory as xdg
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
#Config.train_files=['/home/STT/data/ldc.csv']
|
def _exit_with_error(message):
    """Report *message* via the STT training logger and exit with status 1."""
    # Imported lazily so the import is only attempted on the failure path.
    # NOTE(review): module path follows the upstream Coqui STT layout
    # (util/logging.py) -- confirm against the checked-out STT tree.
    from STT.training.coqui_stt_training.util.logging import log_error

    log_error(message)
    sys.exit(1)


def initialize_globals(c):
    """Fill in the derived settings on config *c* and install it globally.

    Starting from the user-facing options already present on ``c`` this:

    * parses the augmentation specs and optionally prepends sample-rate
      normalization,
    * converts ``read_buffer`` with ``parse_file_size``,
    * applies fallbacks for dropout rates, checkpoint/summary directories and
      the ``load_train`` / ``load_evaluate`` modes,
    * builds the TensorFlow session config and the list of usable devices,
    * derives the layer sizes and the audio window/step sizes in samples,
    * validates a few option combinations, and finally
    * stores ``c`` in ``_ConfigSingleton._config``.

    Args:
        c: Mutable configuration object exposing the options read below as
            attributes; it is modified in place.

    Raises:
        SystemExit: With status 1 when the window length/step do not yield an
            integer number of samples, when ``one_shot_infer`` does not point
            to an existing path, or when ``train_cudnn`` and ``load_cudnn``
            are both set.
    """
    # Augmentations
    c.augmentations = parse_augmentations(c.augment)
    print(f"Parsed augmentations from flags: {c.augmentations}")
    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
        print(
            "Due to current feature-cache settings the exact same sample augmentations of the first "
            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
            "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
        )

    if c.normalize_sample_rate:
        # Sample-rate normalization has to run before every other augmentation.
        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c.augmentations

    # Caching
    if c.cache_for_epochs == 1:
        print(
            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
        )

    # Read-buffer
    c.read_buffer = parse_file_size(c.read_buffer)

    # Set default dropout rates: a negative value means "inherit dropout_rate".
    if c.dropout_rate2 < 0:
        c.dropout_rate2 = c.dropout_rate
    if c.dropout_rate3 < 0:
        c.dropout_rate3 = c.dropout_rate
    if c.dropout_rate6 < 0:
        c.dropout_rate6 = c.dropout_rate

    # Set default checkpoint dir
    if not c.checkpoint_dir:
        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))

    # Unknown load modes silently fall back to "auto".
    if c.load_train not in ["last", "best", "init", "auto"]:
        c.load_train = "auto"

    if c.load_evaluate not in ["last", "best", "auto"]:
        c.load_evaluate = "auto"

    # Set default summary dir
    if not c.summary_dir:
        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tfv1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=c.log_placement,
        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
    )

    # CPU device
    c.cpu_device = "/cpu:0"

    # Available GPU devices
    c.available_devices = get_available_gpus(c.session_config)

    # If there is no GPU available, we fall back to CPU based operation
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    # NOTE(review): the path is cleared unconditionally, so the ``elif`` branch
    # below can never run and any alphabet must already be set on ``c`` by the
    # caller. Kept as-is to preserve behaviour -- confirm this is intended.
    c.alphabet_config_path = ""

    if c.bytes_output_mode:
        # Imported lazily; only needed in bytes output mode.
        from coqui_stt_ctcdecoder import UTF8Alphabet

        c.alphabet = UTF8Alphabet()
    elif c.alphabet_config_path:
        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))

    # Geometric Constants
    # ===================

    # For an explanation of the meaning of the geometric constants, please refer to
    # doc/Geometry.md

    # Number of MFCC features
    c.n_input = 26  # TODO: Determine this programmatically from the sample rate

    # The number of frames in the context
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers (all taken from the existing c.n_hidden)
    c.n_hidden_1 = c.n_hidden

    c.n_hidden_2 = c.n_hidden

    c.n_hidden_5 = c.n_hidden

    # LSTM cell state dimension
    c.n_cell_dim = c.n_hidden

    # The number of units in the third layer, which feeds in to the LSTM
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    try:
        c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label
    except AttributeError:
        # No alphabet has been attached to the config, so the output layer
        # size cannot be derived here; leave n_hidden_6 untouched.
        pass

    # Size of audio window in samples
    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
        _exit_with_error(
            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
            "your --feature_win_len value or resample your audio accordingly."
            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
        )

    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)

    # Stride for feature computations in samples
    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
        _exit_with_error(
            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
            "your --feature_win_step value or resample your audio accordingly."
            "".format(
                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
            )
        )

    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)

    if c.one_shot_infer:
        # Imported lazily; only needed when one-shot inference is requested.
        # NOTE(review): module path follows the upstream Coqui STT layout
        # (util/io.py) -- confirm against the checked-out STT tree.
        from STT.training.coqui_stt_training.util.io import path_exists_remote

        if not path_exists_remote(c.one_shot_infer):
            _exit_with_error("Path specified in --one_shot_infer is not a valid file.")

    if c.train_cudnn and c.load_cudnn:
        _exit_with_error(
            "Trying to use --train_cudnn, but --load_cudnn "
            "was also specified. The --load_cudnn flag is only "
            "needed when converting a CuDNN RNN checkpoint to "
            "a CPU-capable graph. If your system is capable of "
            "using CuDNN RNN, you can just specify the CuDNN RNN "
            "checkpoint normally with --save_checkpoint_dir."
        )

    # If separate save and load flags were not specified, default to load and save
    # from the same dir.
    if not c.save_checkpoint_dir:
        c.save_checkpoint_dir = c.checkpoint_dir

    if not c.load_checkpoint_dir:
        c.load_checkpoint_dir = c.checkpoint_dir

    # Publish the finished config as the process-wide singleton.
    _ConfigSingleton._config = c  # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from STT.training.coqui_stt_training.train import train, test, early_training_checks

# Build a fresh config and apply the notebook's overrides in one pass.
# The same CSV is used for the train, dev and test splits.
Config = _SttConfig()

_overrides = {
    "alphabet": Alphabet('/home/STT/data/alphabet.txt'),
    "train_files": ['/home/STT/data/ldc93s1.csv'],
    "dev_files": ['/home/STT/data/ldc93s1.csv'],
    "test_files": ['/home/STT/data/ldc93s1.csv'],
    "n_hidden": 100,
    "epochs": 200,
}
for _option, _setting in _overrides.items():
    setattr(Config, _option, _setting)

# Derive the remaining settings and register Config as the global config.
initialize_globals(Config)

# Debugging aid: uncomment to dump the resolved configuration.
#print(Config.to_json())

early_training_checks()

train()

# Start from a clean default graph before running the test pass.
tfv1.reset_default_graph()
test()
|
||||||
|
Loading…
Reference in New Issue
Block a user