Formatting changes from pre-commit

This commit is contained in:
Josh Meyer 2021-07-22 05:39:45 -04:00
parent b6d40a3451
commit b4827fa462
2 changed files with 33 additions and 12 deletions

View File

@ -6,7 +6,7 @@ from coqui_stt_training.train import train, test, early_training_checks
import tensorflow.compat.v1 as tfv1
# only one GPU for only one training sample
os.environ['CUDA_VISIBLE_DEVICES']='0'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
download_ldc("data/ldc93s1")
@ -18,7 +18,7 @@ initialize_globals_from_args(
test_files=["data/ldc93s1/ldc93s1.csv"],
augment=["time_mask"],
n_hidden=100,
epochs=200
epochs=200,
)
early_training_checks()

View File

@ -79,16 +79,24 @@ class _SttConfig(Coqpit):
# from the same dir.
# if save_checkpoint_dir hasn't been set, or checkpoint_dir is new
if (not self.save_checkpoint_dir) or (self.save_checkpoint_dir is not self.checkpoint_dir):
if (not self.save_checkpoint_dir) or (
self.save_checkpoint_dir is not self.checkpoint_dir
):
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
self.checkpoint_dir = xdg.save_data_path(
os.path.join("stt", "checkpoints")
)
self.save_checkpoint_dir = self.checkpoint_dir
else:
self.save_checkpoint_dir = self.checkpoint_dir
# if load_checkpoint_dir hasn't been set, or checkpoint_dir is new
if (not self.load_checkpoint_dir) or (self.load_checkpoint_dir is not self.checkpoint_dir):
if (not self.load_checkpoint_dir) or (
self.load_checkpoint_dir is not self.checkpoint_dir
):
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.load_data_path(os.path.join("stt", "checkpoints"))
self.checkpoint_dir = xdg.load_data_path(
os.path.join("stt", "checkpoints")
)
self.load_checkpoint_dir = self.checkpoint_dir
else:
self.load_checkpoint_dir = self.checkpoint_dir
@ -138,10 +146,12 @@ class _SttConfig(Coqpit):
# please refer to doc/Geometry.md
# Number of MFCC features
self.n_input = 26 # TODO: Determine this programmatically from the sample rate
self.n_input = 26 # TODO: Determine this programmatically from the sample rate
# The number of frames in the context
self.n_context = 9 # TODO: Determine the optimal value using a validation data set
self.n_context = (
9 # TODO: Determine the optimal value using a validation data set
)
# Number of units in hidden layers
self.n_hidden = self.n_hidden
@ -171,11 +181,17 @@ class _SttConfig(Coqpit):
"--feature_win_len value ({}) in milliseconds ({}) multiplied "
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
"your --feature_win_len value or resample your audio accordingly."
"".format(self.feature_win_len, self.feature_win_len / 1000, self.audio_sample_rate)
"".format(
self.feature_win_len,
self.feature_win_len / 1000,
self.audio_sample_rate,
)
)
sys.exit(1)
self.audio_window_samples = self.audio_sample_rate * (self.feature_win_len / 1000)
self.audio_window_samples = self.audio_sample_rate * (
self.feature_win_len / 1000
)
# Stride for feature computations in samples
if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
@ -184,12 +200,16 @@ class _SttConfig(Coqpit):
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
"your --feature_win_step value or resample your audio accordingly."
"".format(
self.feature_win_step, self.feature_win_step / 1000, self.audio_sample_rate
self.feature_win_step,
self.feature_win_step / 1000,
self.audio_sample_rate,
)
)
sys.exit(1)
self.audio_step_samples = self.audio_sample_rate * (self.feature_win_step / 1000)
self.audio_step_samples = self.audio_sample_rate * (
self.feature_win_step / 1000
)
if self.one_shot_infer:
if not path_exists_remote(self.one_shot_infer):
@ -718,6 +738,7 @@ def initialize_globals_from_cli():
c.__post_init__()
_ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_args(**override_args):
# Update Config with new args
c = _SttConfig(**override_args)