Merge pull request #2026 from coqui-ai/save-scorer-alphabet-savedmodel
Save scorer and alphabet with SavedModel exports
This commit is contained in:
commit
154a67fb2c
|
@ -240,6 +240,13 @@ def export_savedmodel():
|
|||
)
|
||||
|
||||
builder.save()
|
||||
|
||||
# Copy scorer and alphabet alongside SavedModel
|
||||
if Config.scorer_path:
|
||||
print(f"Saving {Config.scorer_path} to {Config.export_dir}")
|
||||
shutil.copy(Config.scorer_path, Config.export_dir)
|
||||
shutil.copy(Config.effective_alphabet_path, Config.export_dir)
|
||||
|
||||
log_info(f"Exported SavedModel to {Config.export_dir}")
|
||||
|
||||
|
||||
|
|
|
@ -11,8 +11,6 @@ DESIRED_LOG_LEVEL = (
|
|||
)
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
@ -59,11 +57,7 @@ from .util.config import (
|
|||
)
|
||||
from .util.feeding import create_dataset
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
from .util.io import (
|
||||
is_remote_path,
|
||||
open_remote,
|
||||
remove_remote,
|
||||
)
|
||||
from .util.io import remove_remote
|
||||
|
||||
|
||||
# Accuracy and Loss
|
||||
|
@ -416,18 +410,6 @@ def train():
|
|||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||
best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev")
|
||||
|
||||
# Save flags next to checkpoints
|
||||
if not is_remote_path(Config.save_checkpoint_dir):
|
||||
os.makedirs(Config.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(Config.save_checkpoint_dir, "flags.txt")
|
||||
with open_remote(flags_file, "w") as fout:
|
||||
json.dump(Config.serialize(), fout, indent=2)
|
||||
|
||||
# Serialize alphabet alongside checkpoint
|
||||
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
|
||||
with open_remote(preserved_alphabet_file, "wb") as fout:
|
||||
fout.write(Config.alphabet.SerializeText())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug("Session opened.")
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
@ -17,7 +18,7 @@ from .augmentations import NormalizeSampleRate, parse_augmentations
|
|||
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
|
||||
from .gpu import get_available_gpus
|
||||
from .helpers import parse_file_size
|
||||
from .io import path_exists_remote
|
||||
from .io import is_remote_path, open_remote, path_exists_remote
|
||||
|
||||
|
||||
class _ConfigSingleton:
|
||||
|
@ -161,6 +162,7 @@ class BaseSttConfig(Coqpit):
|
|||
self.alphabet = UTF8Alphabet()
|
||||
elif self.alphabet_config_path:
|
||||
self.alphabet = Alphabet(self.alphabet_config_path)
|
||||
self.effective_alphabet_path = self.alphabet_config_path
|
||||
elif os.path.exists(loaded_checkpoint_alphabet_file):
|
||||
print(
|
||||
"I --alphabet_config_path not specified, but found an alphabet file "
|
||||
|
@ -168,6 +170,7 @@ class BaseSttConfig(Coqpit):
|
|||
"Will use this alphabet file for this run."
|
||||
)
|
||||
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
|
||||
self.effective_alphabet_path = loaded_checkpoint_alphabet_file
|
||||
elif self.train_files and self.dev_files and self.test_files:
|
||||
# If all subsets are in the same folder and there's an alphabet file
|
||||
# alongside them, use it.
|
||||
|
@ -185,6 +188,7 @@ class BaseSttConfig(Coqpit):
|
|||
"Will use this alphabet file for this run."
|
||||
)
|
||||
self.alphabet = Alphabet(str(possible_alphabet))
|
||||
self.effective_alphabet_path = possible_alphabet
|
||||
|
||||
if not self.alphabet:
|
||||
# Generate alphabet automatically from input dataset, but only if
|
||||
|
@ -199,6 +203,7 @@ class BaseSttConfig(Coqpit):
|
|||
characters, alphabet = create_alphabet_from_sources(sources)
|
||||
print(f"I Generated alphabet characters: {characters}.")
|
||||
self.alphabet = alphabet
|
||||
self.effective_alphabet_path = saved_checkpoint_alphabet_file
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Missing --alphabet_config_path flag. Couldn't find an alphabet file "
|
||||
|
@ -208,6 +213,17 @@ class BaseSttConfig(Coqpit):
|
|||
"be generated automatically."
|
||||
)
|
||||
|
||||
# Save flags next to checkpoints
|
||||
if not is_remote_path(self.save_checkpoint_dir):
|
||||
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
|
||||
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
|
||||
with open_remote(flags_file, "w") as fout:
|
||||
json.dump(self.serialize(), fout, indent=2)
|
||||
|
||||
# Serialize alphabet alongside checkpoint
|
||||
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
|
||||
fout.write(self.alphabet.SerializeText())
|
||||
|
||||
# Geometric Constants
|
||||
# ===================
|
||||
|
||||
|
|
Loading…
Reference in New Issue