Fix checkpoint setting logic

This commit is contained in:
Josh Meyer 2021-07-21 11:53:11 -04:00
parent 90ce16fa15
commit 3438dd2beb

View File

@ -75,8 +75,23 @@ class _SttConfig(Coqpit):
self.dropout_rate6 = self.dropout_rate
# Set default checkpoint dir
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
# If separate save and load flags were not specified, default to load and save
# from the same dir.
# if save_checkpoint_dir hasn't been set, or checkpoint_dir is new
if (not self.save_checkpoint_dir) or (self.save_checkpoint_dir is not self.checkpoint_dir):
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
self.save_checkpoint_dir = self.checkpoint_dir
else:
self.save_checkpoint_dir = self.checkpoint_dir
# if load_checkpoint_dir hasn't been set, or checkpoint_dir is new
if (not self.load_checkpoint_dir) or (self.load_checkpoint_dir is not self.checkpoint_dir):
if not self.checkpoint_dir:
self.checkpoint_dir = xdg.load_data_path(os.path.join("stt", "checkpoints"))
self.load_checkpoint_dir = self.checkpoint_dir
else:
self.load_checkpoint_dir = self.checkpoint_dir
if self.load_train not in ["last", "best", "init", "auto"]:
self.load_train = "auto"
@ -192,14 +207,6 @@ class _SttConfig(Coqpit):
)
sys.exit(1)
# If separate save and load flags were not specified, default to load and save
# from the same dir.
if not self.save_checkpoint_dir:
self.save_checkpoint_dir = self.checkpoint_dir
if not self.load_checkpoint_dir:
self.load_checkpoint_dir = self.checkpoint_dir
train_files: List[str] = field(
default_factory=list,
metadata=dict(