Merge pull request #2026 from coqui-ai/save-scorer-alphabet-savedmodel

Save Scorer and alphabet with SavedModel exports
This commit is contained in:
Reuben Morais 2021-11-19 20:07:32 +01:00 committed by GitHub
commit 154a67fb2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 20 deletions

View File

@ -240,6 +240,13 @@ def export_savedmodel():
)
builder.save()
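# Copy scorer and alphabet alongside SavedModel
# NOTE(review): the alphabet copy below is unconditional, but the UTF8Alphabet
# (bytes output mode) branch in the config does not appear to set
# Config.effective_alphabet_path - confirm it has a usable value in that mode.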
if Config.scorer_path:
print(f"Saving {Config.scorer_path} to {Config.export_dir}")
shutil.copy(Config.scorer_path, Config.export_dir)
shutil.copy(Config.effective_alphabet_path, Config.export_dir)
log_info(f"Exported SavedModel to {Config.export_dir}")

View File

@ -11,8 +11,6 @@ DESIRED_LOG_LEVEL = (
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import json
import shutil
import time
from datetime import datetime
from pathlib import Path
@ -59,11 +57,7 @@ from .util.config import (
)
from .util.feeding import create_dataset
from .util.helpers import check_ctcdecoder_version
from .util.io import (
is_remote_path,
open_remote,
remove_remote,
)
from .util.io import remove_remote
# Accuracy and Loss
@ -416,18 +410,6 @@ def train():
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev")
# Save flags next to checkpoints
if not is_remote_path(Config.save_checkpoint_dir):
os.makedirs(Config.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(Config.save_checkpoint_dir, "flags.txt")
with open_remote(flags_file, "w") as fout:
json.dump(Config.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
with open_remote(preserved_alphabet_file, "wb") as fout:
fout.write(Config.alphabet.SerializeText())
with tfv1.Session(config=Config.session_config) as session:
log_debug("Session opened.")

View File

@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function
import json
import os
import sys
from dataclasses import asdict, dataclass, field
@ -17,7 +18,7 @@ from .augmentations import NormalizeSampleRate, parse_augmentations
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
from .gpu import get_available_gpus
from .helpers import parse_file_size
from .io import path_exists_remote
from .io import is_remote_path, open_remote, path_exists_remote
class _ConfigSingleton:
@ -161,6 +162,7 @@ class BaseSttConfig(Coqpit):
self.alphabet = UTF8Alphabet()
elif self.alphabet_config_path:
self.alphabet = Alphabet(self.alphabet_config_path)
self.effective_alphabet_path = self.alphabet_config_path
elif os.path.exists(loaded_checkpoint_alphabet_file):
print(
"I --alphabet_config_path not specified, but found an alphabet file "
@ -168,6 +170,7 @@ class BaseSttConfig(Coqpit):
"Will use this alphabet file for this run."
)
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
self.effective_alphabet_path = loaded_checkpoint_alphabet_file
elif self.train_files and self.dev_files and self.test_files:
# If all subsets are in the same folder and there's an alphabet file
# alongside them, use it.
@ -185,6 +188,7 @@ class BaseSttConfig(Coqpit):
"Will use this alphabet file for this run."
)
self.alphabet = Alphabet(str(possible_alphabet))
self.effective_alphabet_path = possible_alphabet
if not self.alphabet:
# Generate alphabet automatically from input dataset, but only if
@ -199,6 +203,7 @@ class BaseSttConfig(Coqpit):
characters, alphabet = create_alphabet_from_sources(sources)
print(f"I Generated alphabet characters: {characters}.")
self.alphabet = alphabet
self.effective_alphabet_path = saved_checkpoint_alphabet_file
else:
raise RuntimeError(
"Missing --alphabet_config_path flag. Couldn't find an alphabet file "
@ -208,6 +213,17 @@ class BaseSttConfig(Coqpit):
"be generated automatically."
)
# Save flags next to checkpoints
if not is_remote_path(self.save_checkpoint_dir):
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
with open_remote(flags_file, "w") as fout:
json.dump(self.serialize(), fout, indent=2)
# Serialize alphabet alongside checkpoint
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
fout.write(self.alphabet.SerializeText())
# Geometric Constants
# ===================