Save Scorer and alphabet with SavedModel exports
This commit is contained in:
parent
3020949075
commit
d6456ae4aa
@ -240,6 +240,13 @@ def export_savedmodel():
|
|||||||
)
|
)
|
||||||
|
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
|
# Copy scorer and alphabet alongside SavedModel
|
||||||
|
if Config.scorer_path:
|
||||||
|
print(f"Saving {Config.scorer_path} to {Config.export_dir}")
|
||||||
|
shutil.copy(Config.scorer_path, Config.export_dir)
|
||||||
|
shutil.copy(Config.effective_alphabet_path, Config.export_dir)
|
||||||
|
|
||||||
log_info(f"Exported SavedModel to {Config.export_dir}")
|
log_info(f"Exported SavedModel to {Config.export_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,8 +11,6 @@ DESIRED_LOG_LEVEL = (
|
|||||||
)
|
)
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
||||||
|
|
||||||
import json
|
|
||||||
import shutil
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -59,11 +57,7 @@ from .util.config import (
|
|||||||
)
|
)
|
||||||
from .util.feeding import create_dataset
|
from .util.feeding import create_dataset
|
||||||
from .util.helpers import check_ctcdecoder_version
|
from .util.helpers import check_ctcdecoder_version
|
||||||
from .util.io import (
|
from .util.io import remove_remote
|
||||||
is_remote_path,
|
|
||||||
open_remote,
|
|
||||||
remove_remote,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Accuracy and Loss
|
# Accuracy and Loss
|
||||||
@ -416,18 +410,6 @@ def train():
|
|||||||
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
best_dev_saver = tfv1.train.Saver(max_to_keep=1)
|
||||||
best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev")
|
best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev")
|
||||||
|
|
||||||
# Save flags next to checkpoints
|
|
||||||
if not is_remote_path(Config.save_checkpoint_dir):
|
|
||||||
os.makedirs(Config.save_checkpoint_dir, exist_ok=True)
|
|
||||||
flags_file = os.path.join(Config.save_checkpoint_dir, "flags.txt")
|
|
||||||
with open_remote(flags_file, "w") as fout:
|
|
||||||
json.dump(Config.serialize(), fout, indent=2)
|
|
||||||
|
|
||||||
# Serialize alphabet alongside checkpoint
|
|
||||||
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
|
|
||||||
with open_remote(preserved_alphabet_file, "wb") as fout:
|
|
||||||
fout.write(Config.alphabet.SerializeText())
|
|
||||||
|
|
||||||
with tfv1.Session(config=Config.session_config) as session:
|
with tfv1.Session(config=Config.session_config) as session:
|
||||||
log_debug("Session opened.")
|
log_debug("Session opened.")
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
@ -17,7 +18,7 @@ from .augmentations import NormalizeSampleRate, parse_augmentations
|
|||||||
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
|
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
|
||||||
from .gpu import get_available_gpus
|
from .gpu import get_available_gpus
|
||||||
from .helpers import parse_file_size
|
from .helpers import parse_file_size
|
||||||
from .io import path_exists_remote
|
from .io import is_remote_path, open_remote, path_exists_remote
|
||||||
|
|
||||||
|
|
||||||
class _ConfigSingleton:
|
class _ConfigSingleton:
|
||||||
@ -161,6 +162,7 @@ class BaseSttConfig(Coqpit):
|
|||||||
self.alphabet = UTF8Alphabet()
|
self.alphabet = UTF8Alphabet()
|
||||||
elif self.alphabet_config_path:
|
elif self.alphabet_config_path:
|
||||||
self.alphabet = Alphabet(self.alphabet_config_path)
|
self.alphabet = Alphabet(self.alphabet_config_path)
|
||||||
|
self.effective_alphabet_path = self.alphabet_config_path
|
||||||
elif os.path.exists(loaded_checkpoint_alphabet_file):
|
elif os.path.exists(loaded_checkpoint_alphabet_file):
|
||||||
print(
|
print(
|
||||||
"I --alphabet_config_path not specified, but found an alphabet file "
|
"I --alphabet_config_path not specified, but found an alphabet file "
|
||||||
@ -168,6 +170,7 @@ class BaseSttConfig(Coqpit):
|
|||||||
"Will use this alphabet file for this run."
|
"Will use this alphabet file for this run."
|
||||||
)
|
)
|
||||||
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
|
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
|
||||||
|
self.effective_alphabet_path = loaded_checkpoint_alphabet_file
|
||||||
elif self.train_files and self.dev_files and self.test_files:
|
elif self.train_files and self.dev_files and self.test_files:
|
||||||
# If all subsets are in the same folder and there's an alphabet file
|
# If all subsets are in the same folder and there's an alphabet file
|
||||||
# alongside them, use it.
|
# alongside them, use it.
|
||||||
@ -185,6 +188,7 @@ class BaseSttConfig(Coqpit):
|
|||||||
"Will use this alphabet file for this run."
|
"Will use this alphabet file for this run."
|
||||||
)
|
)
|
||||||
self.alphabet = Alphabet(str(possible_alphabet))
|
self.alphabet = Alphabet(str(possible_alphabet))
|
||||||
|
self.effective_alphabet_path = possible_alphabet
|
||||||
|
|
||||||
if not self.alphabet:
|
if not self.alphabet:
|
||||||
# Generate alphabet automatically from input dataset, but only if
|
# Generate alphabet automatically from input dataset, but only if
|
||||||
@ -199,6 +203,7 @@ class BaseSttConfig(Coqpit):
|
|||||||
characters, alphabet = create_alphabet_from_sources(sources)
|
characters, alphabet = create_alphabet_from_sources(sources)
|
||||||
print(f"I Generated alphabet characters: {characters}.")
|
print(f"I Generated alphabet characters: {characters}.")
|
||||||
self.alphabet = alphabet
|
self.alphabet = alphabet
|
||||||
|
self.effective_alphabet_path = saved_checkpoint_alphabet_file
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Missing --alphabet_config_path flag. Couldn't find an alphabet file "
|
"Missing --alphabet_config_path flag. Couldn't find an alphabet file "
|
||||||
@ -208,6 +213,17 @@ class BaseSttConfig(Coqpit):
|
|||||||
"be generated automatically."
|
"be generated automatically."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Save flags next to checkpoints
|
||||||
|
if not is_remote_path(self.save_checkpoint_dir):
|
||||||
|
os.makedirs(self.save_checkpoint_dir, exist_ok=True)
|
||||||
|
flags_file = os.path.join(self.save_checkpoint_dir, "flags.txt")
|
||||||
|
with open_remote(flags_file, "w") as fout:
|
||||||
|
json.dump(self.serialize(), fout, indent=2)
|
||||||
|
|
||||||
|
# Serialize alphabet alongside checkpoint
|
||||||
|
with open_remote(saved_checkpoint_alphabet_file, "wb") as fout:
|
||||||
|
fout.write(self.alphabet.SerializeText())
|
||||||
|
|
||||||
# Geometric Constants
|
# Geometric Constants
|
||||||
# ===================
|
# ===================
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user