Move checking logic into __post_init__()

Josh Meyer 2021-07-21 05:00:05 -04:00
parent 5b4fa27467
commit 4dc565beca
5 changed files with 174 additions and 174 deletions
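
In short, entry-point scripts no longer construct an _SttConfig themselves and pass it to initialize_config_globals(); that checking logic now lives in _SttConfig.__post_init__(), and callers use a single helper. A minimal before/after sketch of the calling convention, paraphrased from the diffs below:

    # before: each entry point wired up the config by hand
    Config = _SttConfig()
    Config.parse_args(arg_prefix="")
    initialize_config_globals(Config)

    # after: one call parses CLI args, runs the __post_init__ checks,
    # and installs the result as the global Config singleton
    from coqui_stt_training.util.config import initialize_globals_from_cli
    initialize_globals_from_cli()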

lm_optimizer.py (View File)

@@ -10,7 +10,7 @@ import tensorflow.compat.v1 as tfv1
 from coqui_stt_ctcdecoder import Scorer
 from coqui_stt_training.evaluate import evaluate
 from coqui_stt_training.train import create_model
-from coqui_stt_training.util.config import Config, _SttConfig, initialize_config_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.evaluate_tools import wer_cer_batch
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import log_error
@@ -52,8 +52,7 @@ def objective(trial):
 def main(_):
-    Config = _SttConfig()
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
     if not FLAGS.test_files:
         log_error(

training/coqui_stt_training/evaluate.py (View File)

@@ -17,9 +17,8 @@ from .util.augmentations import NormalizeSampleRate
 from .util.checkpoints import load_graph_for_evaluation
 from .util.config import (
     Config,
-    _SttConfig,
     create_progressbar,
-    initialize_config_globals,
+    initialize_globals_from_cli,
     log_error,
     log_progress,
 )
@@ -170,8 +169,7 @@ def evaluate(test_csvs, create_model):
 def main():
-    Config = _SttConfig()
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
     if not Config.test_files:
         log_error(

training/coqui_stt_training/train.py (View File)

@@ -44,9 +44,8 @@ from .util.checkpoints import (
 )
 from .util.config import (
     Config,
-    _SttConfig,
     create_progressbar,
-    initialize_config_globals,
+    initialize_globals_from_cli,
     log_debug,
     log_error,
     log_info,
@@ -1250,10 +1249,7 @@ def early_training_checks():
 def main():
-    Config = _SttConfig()
-    Config.parse_args(arg_prefix="")  # parse CLI args
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
     early_training_checks()
     if Config.train_files:

training/coqui_stt_training/util/config.py (View File)

@@ -36,6 +36,163 @@ Config = _ConfigSingleton()  # pylint: disable=invalid-name
 @dataclass
 class _SttConfig(Coqpit):
+    def __post_init__(self):
+        # Augmentations
+        self.augmentations = parse_augmentations(self.augment)
+        if self.augmentations:
+            print(f"Parsed augmentations: {self.augmentations}")
+        if self.augmentations and self.feature_cache and self.cache_for_epochs == 0:
+            print(
+                "Due to current feature-cache settings the exact same sample augmentations of the first "
+                "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
+                "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
+            )
+        if self.normalize_sample_rate:
+            self.augmentations = [NormalizeSampleRate(self.audio_sample_rate)] + self[
+                "augmentations"
+            ]
+        # Caching
+        if self.cache_for_epochs == 1:
+            print(
+                "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it. "
+                "You can either set --cache_for_epochs > 1, or not use feature caching at all."
+            )
+        # Read-buffer
+        self.read_buffer = parse_file_size(self.read_buffer)
+        # Set default dropout rates
+        if self.dropout_rate2 < 0:
+            self.dropout_rate2 = self.dropout_rate
+        if self.dropout_rate3 < 0:
+            self.dropout_rate3 = self.dropout_rate
+        if self.dropout_rate6 < 0:
+            self.dropout_rate6 = self.dropout_rate
+        # Set default checkpoint dir
+        if not self.checkpoint_dir:
+            self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
+        if self.load_train not in ["last", "best", "init", "auto"]:
+            self.load_train = "auto"
+        if self.load_evaluate not in ["last", "best", "auto"]:
+            self.load_evaluate = "auto"
+        # Set default summary dir
+        if not self.summary_dir:
+            self.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
+        # Standard session configuration that'll be used for all new sessions.
+        self.session_config = tfv1.ConfigProto(
+            allow_soft_placement=True,
+            log_device_placement=self.log_placement,
+            inter_op_parallelism_threads=self.inter_op_parallelism_threads,
+            intra_op_parallelism_threads=self.intra_op_parallelism_threads,
+            gpu_options=tfv1.GPUOptions(allow_growth=self.use_allow_growth),
+        )
+        # CPU device
+        self.cpu_device = "/cpu:0"
+        # Available GPU devices
+        self.available_devices = get_available_gpus(self.session_config)
+        # If there is no GPU available, we fall back to CPU based operation
+        if not self.available_devices:
+            self.available_devices = [self.cpu_device]
+        if self.bytes_output_mode:
+            self.alphabet = UTF8Alphabet()
+        elif self.alphabet_config_path:
+            self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
+        # Geometric Constants
+        # ===================
+        # For an explanation of the meaning of the geometric constants,
+        # please refer to doc/Geometry.md
+        # Number of MFCC features
+        self.n_input = 26  # TODO: Determine this programmatically from the sample rate
+        # The number of frames in the context
+        self.n_context = 9  # TODO: Determine the optimal value using a validation data set
+        # Number of units in hidden layers
+        self.n_hidden = self.n_hidden
+        self.n_hidden_1 = self.n_hidden
+        self.n_hidden_2 = self.n_hidden
+        self.n_hidden_5 = self.n_hidden
+        # LSTM cell state dimension
+        self.n_cell_dim = self.n_hidden
+        # The number of units in the third layer, which feeds in to the LSTM
+        self.n_hidden_3 = self.n_cell_dim
+        # Dims in last layer = number of characters in alphabet plus one
+        try:
+            # +1 for CTC blank label
+            self.n_hidden_6 = self.alphabet.GetSize() + 1
+        except AttributeError:
+            pass
+        # Size of audio window in samples
+        if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_len value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_len value or resample your audio accordingly."
+                "".format(self.feature_win_len, self.feature_win_len / 1000, self.audio_sample_rate)
+            )
+            sys.exit(1)
+        self.audio_window_samples = self.audio_sample_rate * (self.feature_win_len / 1000)
+        # Stride for feature computations in samples
+        if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_step value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_step value or resample your audio accordingly."
+                "".format(
+                    self.feature_win_step, self.feature_win_step / 1000, self.audio_sample_rate
+                )
+            )
+            sys.exit(1)
+        self.audio_step_samples = self.audio_sample_rate * (self.feature_win_step / 1000)
+        if self.one_shot_infer:
+            if not path_exists_remote(self.one_shot_infer):
+                log_error("Path specified in --one_shot_infer is not a valid file.")
+                sys.exit(1)
+        if self.train_cudnn and self.load_cudnn:
+            log_error(
+                "Trying to use --train_cudnn, but --load_cudnn "
+                "was also specified. The --load_cudnn flag is only "
+                "needed when converting a CuDNN RNN checkpoint to "
+                "a CPU-capable graph. If your system is capable of "
+                "using CuDNN RNN, you can just specify the CuDNN RNN "
+                "checkpoint normally with --save_checkpoint_dir."
+            )
+            sys.exit(1)
+        # If separate save and load flags were not specified, default to load and save
+        # from the same dir.
+        if not self.save_checkpoint_dir:
+            self.save_checkpoint_dir = self.checkpoint_dir
+        if not self.load_checkpoint_dir:
+            self.load_checkpoint_dir = self.checkpoint_dir
     train_files: List[str] = field(
         default_factory=list,
         metadata=dict(
@@ -541,165 +698,15 @@ class _SttConfig(Coqpit):
     )
-def initialize_config_globals(c):
-    """
-    input: config class object (i.e. coqpit.Coqpit)
-    """
-    # Augmentations
-    c.augmentations = parse_augmentations(c.augment)
-    print(f"Parsed augmentations from flags: {c.augmentations}")
-    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
-        print(
-            "Due to current feature-cache settings the exact same sample augmentations of the first "
-            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
-            "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
-        )
-    if c.normalize_sample_rate:
-        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c[
-            "augmentations"
-        ]
-    # Caching
-    if c.cache_for_epochs == 1:
-        print(
-            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
-        )
-    # Read-buffer
-    c.read_buffer = parse_file_size(c.read_buffer)
-    # Set default dropout rates
-    if c.dropout_rate2 < 0:
-        c.dropout_rate2 = c.dropout_rate
-    if c.dropout_rate3 < 0:
-        c.dropout_rate3 = c.dropout_rate
-    if c.dropout_rate6 < 0:
-        c.dropout_rate6 = c.dropout_rate
-    # Set default checkpoint dir
-    if not c.checkpoint_dir:
-        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
-    if c.load_train not in ["last", "best", "init", "auto"]:
-        c.load_train = "auto"
-    if c.load_evaluate not in ["last", "best", "auto"]:
-        c.load_evaluate = "auto"
-    # Set default summary dir
-    if not c.summary_dir:
-        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
-    # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tfv1.ConfigProto(
-        allow_soft_placement=True,
-        log_device_placement=c.log_placement,
-        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
-        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
-        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
-    )
-    # CPU device
-    c.cpu_device = "/cpu:0"
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
-    # If there is no GPU available, we fall back to CPU based operation
-    if not c.available_devices:
-        c.available_devices = [c.cpu_device]
-    if c.bytes_output_mode:
-        c.alphabet = UTF8Alphabet()
-    elif c.alphabet_config_path:
-        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))
-    # Geometric Constants
-    # ===================
-    # For an explanation of the meaning of the geometric constants, please refer to
-    # doc/Geometry.md
-    # Number of MFCC features
-    c.n_input = 26  # TODO: Determine this programmatically from the sample rate
-    # The number of frames in the context
-    c.n_context = 9  # TODO: Determine the optimal value using a validation data set
-    # Number of units in hidden layers
-    c.n_hidden = c.n_hidden
-    c.n_hidden_1 = c.n_hidden
-    c.n_hidden_2 = c.n_hidden
-    c.n_hidden_5 = c.n_hidden
-    # LSTM cell state dimension
-    c.n_cell_dim = c.n_hidden
-    # The number of units in the third layer, which feeds in to the LSTM
-    c.n_hidden_3 = c.n_cell_dim
-    # Units in the last layer = number of characters in the alphabet plus one
-    try:
-        # +1 for CTC blank label
-        c.n_hidden_6 = c.alphabet.GetSize() + 1
-    except:
-        AttributeError
-    # Size of audio window in samples
-    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_len value or resample your audio accordingly."
-            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
-        )
-        sys.exit(1)
-    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)
-    # Stride for feature computations in samples
-    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_step value or resample your audio accordingly."
-            "".format(
-                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
-            )
-        )
-        sys.exit(1)
-    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)
-    if c.one_shot_infer:
-        if not path_exists_remote(c.one_shot_infer):
-            log_error("Path specified in --one_shot_infer is not a valid file.")
-            sys.exit(1)
-    if c.train_cudnn and c.load_cudnn:
-        log_error(
-            "Trying to use --train_cudnn, but --load_cudnn "
-            "was also specified. The --load_cudnn flag is only "
-            "needed when converting a CuDNN RNN checkpoint to "
-            "a CPU-capable graph. If your system is capable of "
-            "using CuDNN RNN, you can just specify the CuDNN RNN "
-            "checkpoint normally with --save_checkpoint_dir."
-        )
-        sys.exit(1)
-    # If separate save and load flags were not specified, default to load and save
-    # from the same dir.
-    if not c.save_checkpoint_dir:
-        c.save_checkpoint_dir = c.checkpoint_dir
-    if not c.load_checkpoint_dir:
-        c.load_checkpoint_dir = c.checkpoint_dir
+def initialize_globals_from_cli():
+    c = _SttConfig()
+    c.parse_args(arg_prefix="")
+    c.__post_init__()
+    _ConfigSingleton._config = c  # pylint: disable=protected-access
+
+
+def initialize_globals_from_args(**override_args):
+    # Update Config with new args
+    c = _SttConfig(**override_args)
+    _ConfigSingleton._config = c  # pylint: disable=protected-access
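
A note on the mechanics here: because _SttConfig is a dataclass (via Coqpit), Python invokes __post_init__() automatically at construction time. initialize_globals_from_cli() still calls it explicitly a second time because parse_args() mutates fields after construction, so the derived values must be recomputed; initialize_globals_from_args() can rely on the automatic call alone. A minimal sketch of the programmatic path, using only names visible in this diff (the field values are illustrative, not taken from this commit):

    from coqui_stt_training.util.config import Config, initialize_globals_from_args

    # overrides are passed straight to _SttConfig(**override_args);
    # __post_init__ then derives the dependent settings
    initialize_globals_from_args(
        train_files=["data/train.csv"],  # illustrative path
        n_hidden=2048,
    )
    print(Config.n_hidden_1)  # mirrors n_hidden, set in __post_init__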

transcribe.py (View File)

@@ -20,7 +20,7 @@ from multiprocessing import Process, cpu_count
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, _SttConfig, initialize_config_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.feeding import split_audio_file
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import (
@@ -42,8 +42,8 @@ def transcribe_file(audio_path, tlog_path):
     )
     from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
-    Config = _SttConfig()
-    initialize_config_globals(Config)
+    initialize_globals_from_cli()
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
     try:
         num_processes = cpu_count()