Load alphabet alongside checkpoint if present, some config fixes/cleanup
This commit is contained in:
parent
87f0a371b1
commit
2b5a844c05
|
@ -688,9 +688,6 @@ def early_training_checks():
|
||||||
"for loading and saving."
|
"for loading and saving."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not Config.alphabet_config_path and not Config.bytes_output_mode:
|
|
||||||
raise RuntimeError("Missing --alphabet_config_path flag, can't continue")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
initialize_globals_from_cli()
|
initialize_globals_from_cli()
|
||||||
|
|
|
@ -118,6 +118,12 @@ class _SttConfig(Coqpit):
|
||||||
if not self.available_devices:
|
if not self.available_devices:
|
||||||
self.available_devices = [self.cpu_device]
|
self.available_devices = [self.cpu_device]
|
||||||
|
|
||||||
|
# If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
|
||||||
|
# look for alphabet file alongside loaded checkpoint.
|
||||||
|
checkpoint_alphabet_file = os.path.join(
|
||||||
|
self.load_checkpoint_dir, "alphabet.txt"
|
||||||
|
)
|
||||||
|
|
||||||
if self.bytes_output_mode and self.alphabet_config_path:
|
if self.bytes_output_mode and self.alphabet_config_path:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"You cannot set --alphabet_config_path *and* --bytes_output_mode"
|
"You cannot set --alphabet_config_path *and* --bytes_output_mode"
|
||||||
|
@ -126,6 +132,15 @@ class _SttConfig(Coqpit):
|
||||||
self.alphabet = UTF8Alphabet()
|
self.alphabet = UTF8Alphabet()
|
||||||
elif self.alphabet_config_path:
|
elif self.alphabet_config_path:
|
||||||
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
|
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
|
||||||
|
elif os.path.exists(checkpoint_alphabet_file):
|
||||||
|
print(
|
||||||
|
"I --alphabet_config_path not specified, but found an alphabet file "
|
||||||
|
f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n"
|
||||||
|
"I Will use this alphabet file for this run."
|
||||||
|
)
|
||||||
|
self.alphabet = Alphabet(checkpoint_alphabet_file)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Missing --alphabet_config_path flag, can't continue")
|
||||||
|
|
||||||
# Geometric Constants
|
# Geometric Constants
|
||||||
# ===================
|
# ===================
|
||||||
|
@ -157,15 +172,12 @@ class _SttConfig(Coqpit):
|
||||||
self.n_hidden_3 = self.n_cell_dim
|
self.n_hidden_3 = self.n_cell_dim
|
||||||
|
|
||||||
# Dims in last layer = number of characters in alphabet plus one
|
# Dims in last layer = number of characters in alphabet plus one
|
||||||
try:
|
# +1 for CTC blank label
|
||||||
# +1 for CTC blank label
|
self.n_hidden_6 = self.alphabet.GetSize() + 1
|
||||||
self.n_hidden_6 = self.alphabet.GetSize() + 1
|
|
||||||
except:
|
|
||||||
AttributeError
|
|
||||||
|
|
||||||
# Size of audio window in samples
|
# Size of audio window in samples
|
||||||
if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
|
if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
|
||||||
log_error(
|
raise RuntimeError(
|
||||||
"--feature_win_len value ({}) in milliseconds ({}) multiplied "
|
"--feature_win_len value ({}) in milliseconds ({}) multiplied "
|
||||||
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
|
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
|
||||||
"your --feature_win_len value or resample your audio accordingly."
|
"your --feature_win_len value or resample your audio accordingly."
|
||||||
|
@ -175,7 +187,6 @@ class _SttConfig(Coqpit):
|
||||||
self.audio_sample_rate,
|
self.audio_sample_rate,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
self.audio_window_samples = self.audio_sample_rate * (
|
self.audio_window_samples = self.audio_sample_rate * (
|
||||||
self.feature_win_len / 1000
|
self.feature_win_len / 1000
|
||||||
|
@ -183,7 +194,7 @@ class _SttConfig(Coqpit):
|
||||||
|
|
||||||
# Stride for feature computations in samples
|
# Stride for feature computations in samples
|
||||||
if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
|
if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
|
||||||
log_error(
|
raise RuntimeError(
|
||||||
"--feature_win_step value ({}) in milliseconds ({}) multiplied "
|
"--feature_win_step value ({}) in milliseconds ({}) multiplied "
|
||||||
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
|
"by --audio_sample_rate value ({}) must be an integer value. Adjust "
|
||||||
"your --feature_win_step value or resample your audio accordingly."
|
"your --feature_win_step value or resample your audio accordingly."
|
||||||
|
@ -193,19 +204,18 @@ class _SttConfig(Coqpit):
|
||||||
self.audio_sample_rate,
|
self.audio_sample_rate,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
self.audio_step_samples = self.audio_sample_rate * (
|
self.audio_step_samples = self.audio_sample_rate * (
|
||||||
self.feature_win_step / 1000
|
self.feature_win_step / 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.one_shot_infer:
|
if self.one_shot_infer and not path_exists_remote(self.one_shot_infer):
|
||||||
if not path_exists_remote(self.one_shot_infer):
|
raise RuntimeError(
|
||||||
log_error("Path specified in --one_shot_infer is not a valid file.")
|
"Path specified in --one_shot_infer is not a valid file."
|
||||||
sys.exit(1)
|
)
|
||||||
|
|
||||||
if self.train_cudnn and self.load_cudnn:
|
if self.train_cudnn and self.load_cudnn:
|
||||||
log_error(
|
raise RuntimeError(
|
||||||
"Trying to use --train_cudnn, but --load_cudnn "
|
"Trying to use --train_cudnn, but --load_cudnn "
|
||||||
"was also specified. The --load_cudnn flag is only "
|
"was also specified. The --load_cudnn flag is only "
|
||||||
"needed when converting a CuDNN RNN checkpoint to "
|
"needed when converting a CuDNN RNN checkpoint to "
|
||||||
|
@ -213,7 +223,6 @@ class _SttConfig(Coqpit):
|
||||||
"using CuDNN RNN, you can just specify the CuDNN RNN "
|
"using CuDNN RNN, you can just specify the CuDNN RNN "
|
||||||
"checkpoint normally with --save_checkpoint_dir."
|
"checkpoint normally with --save_checkpoint_dir."
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# sphinx-doc: training_ref_flags_start
|
# sphinx-doc: training_ref_flags_start
|
||||||
train_files: List[str] = field(
|
train_files: List[str] = field(
|
||||||
|
@ -727,9 +736,7 @@ class _SttConfig(Coqpit):
|
||||||
|
|
||||||
|
|
||||||
def initialize_globals_from_cli():
|
def initialize_globals_from_cli():
|
||||||
c = _SttConfig()
|
c = _SttConfig.init_from_argparse(arg_prefix="")
|
||||||
c.parse_args(arg_prefix="")
|
|
||||||
c.__post_init__()
|
|
||||||
_ConfigSingleton._config = c # pylint: disable=protected-access
|
_ConfigSingleton._config = c # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue