Formatting changes from pre-commit
This commit is contained in:
parent
b6d40a3451
commit
b4827fa462
@ -6,7 +6,7 @@ from coqui_stt_training.train import train, test, early_training_checks
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
# only one GPU for only one training sample
|
||||
os.environ['CUDA_VISIBLE_DEVICES']='0'
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
download_ldc("data/ldc93s1")
|
||||
|
||||
@ -18,7 +18,7 @@ initialize_globals_from_args(
|
||||
test_files=["data/ldc93s1/ldc93s1.csv"],
|
||||
augment=["time_mask"],
|
||||
n_hidden=100,
|
||||
epochs=200
|
||||
epochs=200,
|
||||
)
|
||||
|
||||
early_training_checks()
|
||||
|
||||
@ -79,16 +79,24 @@ class _SttConfig(Coqpit):
|
||||
# from the same dir.
|
||||
|
||||
# if save_checkpoint_dir hasn't been set, or checkpoint_dir is new
|
||||
if (not self.save_checkpoint_dir) or (self.save_checkpoint_dir is not self.checkpoint_dir):
|
||||
if (not self.save_checkpoint_dir) or (
|
||||
self.save_checkpoint_dir is not self.checkpoint_dir
|
||||
):
|
||||
if not self.checkpoint_dir:
|
||||
self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
|
||||
self.checkpoint_dir = xdg.save_data_path(
|
||||
os.path.join("stt", "checkpoints")
|
||||
)
|
||||
self.save_checkpoint_dir = self.checkpoint_dir
|
||||
else:
|
||||
self.save_checkpoint_dir = self.checkpoint_dir
|
||||
# if load_checkpoint_dir hasn't been set, or checkpoint_dir is new
|
||||
if (not self.load_checkpoint_dir) or (self.load_checkpoint_dir is not self.checkpoint_dir):
|
||||
if (not self.load_checkpoint_dir) or (
|
||||
self.load_checkpoint_dir is not self.checkpoint_dir
|
||||
):
|
||||
if not self.checkpoint_dir:
|
||||
self.checkpoint_dir = xdg.load_data_path(os.path.join("stt", "checkpoints"))
|
||||
self.checkpoint_dir = xdg.load_data_path(
|
||||
os.path.join("stt", "checkpoints")
|
||||
)
|
||||
self.load_checkpoint_dir = self.checkpoint_dir
|
||||
else:
|
||||
self.load_checkpoint_dir = self.checkpoint_dir
|
||||
@ -138,10 +146,12 @@ class _SttConfig(Coqpit):
|
||||
# please refer to doc/Geometry.md
|
||||
|
||||
# Number of MFCC features
|
||||
self.n_input = 26 # TODO: Determine this programmatically from the sample rate
|
||||
self.n_input = 26 # TODO: Determine this programmatically from the sample rate
|
||||
|
||||
# The number of frames in the context
|
||||
self.n_context = 9 # TODO: Determine the optimal value using a validation data set
|
||||
self.n_context = (
|
||||
9 # TODO: Determine the optimal value using a validation data set
|
||||
)
|
||||
|
||||
# Number of units in hidden layers
|
||||
self.n_hidden = self.n_hidden
|
||||
@ -171,11 +181,17 @@ class _SttConfig(Coqpit):
|
||||
"--feature_win_len value ({}) in milliseconds ({}) multiplied "
|
||||
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
|
||||
"your --feature_win_len value or resample your audio accordingly."
|
||||
"".format(self.feature_win_len, self.feature_win_len / 1000, self.audio_sample_rate)
|
||||
"".format(
|
||||
self.feature_win_len,
|
||||
self.feature_win_len / 1000,
|
||||
self.audio_sample_rate,
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
self.audio_window_samples = self.audio_sample_rate * (self.feature_win_len / 1000)
|
||||
self.audio_window_samples = self.audio_sample_rate * (
|
||||
self.feature_win_len / 1000
|
||||
)
|
||||
|
||||
# Stride for feature computations in samples
|
||||
if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
|
||||
@ -184,12 +200,16 @@ class _SttConfig(Coqpit):
|
||||
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
|
||||
"your --feature_win_step value or resample your audio accordingly."
|
||||
"".format(
|
||||
self.feature_win_step, self.feature_win_step / 1000, self.audio_sample_rate
|
||||
self.feature_win_step,
|
||||
self.feature_win_step / 1000,
|
||||
self.audio_sample_rate,
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
self.audio_step_samples = self.audio_sample_rate * (self.feature_win_step / 1000)
|
||||
self.audio_step_samples = self.audio_sample_rate * (
|
||||
self.feature_win_step / 1000
|
||||
)
|
||||
|
||||
if self.one_shot_infer:
|
||||
if not path_exists_remote(self.one_shot_infer):
|
||||
@ -718,6 +738,7 @@ def initialize_globals_from_cli():
|
||||
c.__post_init__()
|
||||
_ConfigSingleton._config = c # pylint: disable=protected-access
|
||||
|
||||
|
||||
def initialize_globals_from_args(**override_args):
|
||||
# Update Config with new args
|
||||
c = _SttConfig(**override_args)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user