diff --git a/training/coqui_stt_training/export.py b/training/coqui_stt_training/export.py index 782835bf..ad44fa54 100644 --- a/training/coqui_stt_training/export.py +++ b/training/coqui_stt_training/export.py @@ -240,6 +240,13 @@ def export_savedmodel(): ) builder.save() + + # Copy scorer and alphabet alongside SavedModel + if Config.scorer_path: + print(f"Saving {Config.scorer_path} to {Config.export_dir}") + shutil.copy(Config.scorer_path, Config.export_dir) + shutil.copy(Config.effective_alphabet_path, Config.export_dir) + log_info(f"Exported SavedModel to {Config.export_dir}") diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index d396fd7f..8451e5f0 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -11,8 +11,6 @@ DESIRED_LOG_LEVEL = ( ) os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL -import json -import shutil import time from datetime import datetime from pathlib import Path @@ -59,11 +57,7 @@ from .util.config import ( ) from .util.feeding import create_dataset from .util.helpers import check_ctcdecoder_version -from .util.io import ( - is_remote_path, - open_remote, - remove_remote, -) +from .util.io import remove_remote # Accuracy and Loss @@ -416,18 +410,6 @@ def train(): best_dev_saver = tfv1.train.Saver(max_to_keep=1) best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev") - # Save flags next to checkpoints - if not is_remote_path(Config.save_checkpoint_dir): - os.makedirs(Config.save_checkpoint_dir, exist_ok=True) - flags_file = os.path.join(Config.save_checkpoint_dir, "flags.txt") - with open_remote(flags_file, "w") as fout: - json.dump(Config.serialize(), fout, indent=2) - - # Serialize alphabet alongside checkpoint - preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt") - with open_remote(preserved_alphabet_file, "wb") as fout: - fout.write(Config.alphabet.SerializeText()) - with tfv1.Session(config=Config.session_config) as session: log_debug("Session opened.") diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index 9ca10b80..5c9a189b 100644 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function +import json import os import sys from dataclasses import asdict, dataclass, field @@ -17,7 +18,7 @@ from .augmentations import NormalizeSampleRate, parse_augmentations from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input from .gpu import get_available_gpus from .helpers import parse_file_size -from .io import path_exists_remote +from .io import is_remote_path, open_remote, path_exists_remote class _ConfigSingleton: @@ -161,6 +162,7 @@ class BaseSttConfig(Coqpit): self.alphabet = UTF8Alphabet() elif self.alphabet_config_path: self.alphabet = Alphabet(self.alphabet_config_path) + self.effective_alphabet_path = self.alphabet_config_path elif os.path.exists(loaded_checkpoint_alphabet_file): print( "I --alphabet_config_path not specified, but found an alphabet file " @@ -168,6 +170,7 @@ class BaseSttConfig(Coqpit): "Will use this alphabet file for this run." ) self.alphabet = Alphabet(loaded_checkpoint_alphabet_file) + self.effective_alphabet_path = loaded_checkpoint_alphabet_file elif self.train_files and self.dev_files and self.test_files: # If all subsets are in the same folder and there's an alphabet file # alongside them, use it. @@ -185,6 +188,7 @@ class BaseSttConfig(Coqpit): "Will use this alphabet file for this run." ) self.alphabet = Alphabet(str(possible_alphabet)) + self.effective_alphabet_path = possible_alphabet if not self.alphabet: # Generate alphabet automatically from input dataset, but only if @@ -199,6 +203,7 @@ class BaseSttConfig(Coqpit): characters, alphabet = create_alphabet_from_sources(sources) print(f"I Generated alphabet characters: {characters}.") self.alphabet = alphabet + self.effective_alphabet_path = saved_checkpoint_alphabet_file else: raise RuntimeError( "Missing --alphabet_config_path flag. Couldn't find an alphabet file " @@ -208,6 +213,17 @@ class BaseSttConfig(Coqpit): "be generated automatically." ) + # Save flags next to checkpoints + if not is_remote_path(self.save_checkpoint_dir): + os.makedirs(self.save_checkpoint_dir, exist_ok=True) + flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt") + with open_remote(flags_file, "w") as fout: + json.dump(self.serialize(), fout, indent=2) + + # Serialize alphabet alongside checkpoint + with open_remote(saved_checkpoint_alphabet_file, "wb") as fout: + fout.write(self.alphabet.SerializeText()) + # Geometric Constants # ===================