Generate and save alphabet automatically if dataset is fully specified
This commit is contained in:
parent
2b5a844c05
commit
02adea2d50
|
@ -69,6 +69,21 @@ Alphabet::init(const char *config_file)
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
Alphabet::InitFromLabels(const std::vector<std::string>& labels)
|
||||||
|
{
|
||||||
|
space_label_ = -2;
|
||||||
|
size_ = labels.size();
|
||||||
|
for (int i = 0; i < size_; ++i) {
|
||||||
|
const std::string& label = labels[i];
|
||||||
|
if (label == " ") {
|
||||||
|
space_label_ = i;
|
||||||
|
}
|
||||||
|
label_to_str_[i] = label;
|
||||||
|
str_to_label_[label] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string
|
std::string
|
||||||
Alphabet::SerializeText()
|
Alphabet::SerializeText()
|
||||||
{
|
{
|
||||||
|
|
|
@ -19,6 +19,9 @@ public:
|
||||||
|
|
||||||
virtual int init(const char *config_file);
|
virtual int init(const char *config_file);
|
||||||
|
|
||||||
|
// Initialize directly from sequence of labels.
|
||||||
|
void InitFromLabels(const std::vector<std::string>& labels);
|
||||||
|
|
||||||
// Serialize alphabet into a binary buffer.
|
// Serialize alphabet into a binary buffer.
|
||||||
std::string Serialize();
|
std::string Serialize();
|
||||||
|
|
||||||
|
|
|
@ -45,14 +45,18 @@ class Scorer(swigwrapper.Scorer):
|
||||||
class Alphabet(swigwrapper.Alphabet):
|
class Alphabet(swigwrapper.Alphabet):
|
||||||
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
||||||
|
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path=None):
|
||||||
super(Alphabet, self).__init__()
|
super(Alphabet, self).__init__()
|
||||||
|
if config_path:
|
||||||
err = self.init(config_path.encode("utf-8"))
|
err = self.init(config_path.encode("utf-8"))
|
||||||
if err != 0:
|
if err != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Alphabet initialization failed with error code 0x{:X}".format(err)
|
"Alphabet initialization failed with error code 0x{:X}".format(err)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def InitFromLabels(self, data):
|
||||||
|
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
|
||||||
|
|
||||||
def CanEncodeSingle(self, input):
|
def CanEncodeSingle(self, input):
|
||||||
"""
|
"""
|
||||||
Returns true if the single character/output class has a corresponding label
|
Returns true if the single character/output class has a corresponding label
|
||||||
|
|
|
@ -11,11 +11,13 @@ from attrdict import AttrDict
|
||||||
from coqpit import MISSING, Coqpit, check_argument
|
from coqpit import MISSING, Coqpit, check_argument
|
||||||
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
|
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
|
||||||
from xdg import BaseDirectory as xdg
|
from xdg import BaseDirectory as xdg
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .augmentations import NormalizeSampleRate, parse_augmentations
|
from .augmentations import NormalizeSampleRate, parse_augmentations
|
||||||
from .gpu import get_available_gpus
|
from .gpu import get_available_gpus
|
||||||
from .helpers import parse_file_size
|
from .helpers import parse_file_size
|
||||||
from .io import path_exists_remote
|
from .io import path_exists_remote
|
||||||
|
from .sample_collections import samples_from_sources
|
||||||
|
|
||||||
|
|
||||||
class _ConfigSingleton:
|
class _ConfigSingleton:
|
||||||
|
@ -120,9 +122,12 @@ class _SttConfig(Coqpit):
|
||||||
|
|
||||||
# If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
|
# If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
|
||||||
# look for alphabet file alongside loaded checkpoint.
|
# look for alphabet file alongside loaded checkpoint.
|
||||||
checkpoint_alphabet_file = os.path.join(
|
loaded_checkpoint_alphabet_file = os.path.join(
|
||||||
self.load_checkpoint_dir, "alphabet.txt"
|
self.load_checkpoint_dir, "alphabet.txt"
|
||||||
)
|
)
|
||||||
|
saved_checkpoint_alphabet_file = os.path.join(
|
||||||
|
self.save_checkpoint_dir, "alphabet.txt"
|
||||||
|
)
|
||||||
|
|
||||||
if self.bytes_output_mode and self.alphabet_config_path:
|
if self.bytes_output_mode and self.alphabet_config_path:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -132,15 +137,42 @@ class _SttConfig(Coqpit):
|
||||||
self.alphabet = UTF8Alphabet()
|
self.alphabet = UTF8Alphabet()
|
||||||
elif self.alphabet_config_path:
|
elif self.alphabet_config_path:
|
||||||
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
|
self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
|
||||||
elif os.path.exists(checkpoint_alphabet_file):
|
elif os.path.exists(loaded_checkpoint_alphabet_file):
|
||||||
print(
|
print(
|
||||||
"I --alphabet_config_path not specified, but found an alphabet file "
|
"I --alphabet_config_path not specified, but found an alphabet file "
|
||||||
f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n"
|
f"alongside specified checkpoint ({loaded_checkpoint_alphabet_file}). "
|
||||||
"I Will use this alphabet file for this run."
|
"Will use this alphabet file for this run."
|
||||||
)
|
)
|
||||||
self.alphabet = Alphabet(checkpoint_alphabet_file)
|
self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
|
||||||
|
elif self.train_files and self.dev_files and self.test_files:
|
||||||
|
# Generate alphabet automatically from input dataset, but only if
|
||||||
|
# fully specified, to avoid confusion in case a missing set has extra
|
||||||
|
# characters.
|
||||||
|
print(
|
||||||
|
"I --alphabet_config_path not specified, but all input datasets are "
|
||||||
|
"present (--train_files, --dev_files, --test_files). An alphabet "
|
||||||
|
"will be generated automatically from the data and placed alongside "
|
||||||
|
f"the checkpoint ({saved_checkpoint_alphabet_file})."
|
||||||
|
)
|
||||||
|
characters = set()
|
||||||
|
for sample in tqdm(
|
||||||
|
samples_from_sources(
|
||||||
|
self.train_files + self.dev_files + self.test_files
|
||||||
|
)
|
||||||
|
):
|
||||||
|
characters |= set(sample.transcript)
|
||||||
|
characters = list(sorted(characters))
|
||||||
|
print(f"I Generated alphabet characters: {characters}.")
|
||||||
|
self.alphabet = Alphabet()
|
||||||
|
self.alphabet.InitFromLabels(characters)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Missing --alphabet_config_path flag, can't continue")
|
raise RuntimeError(
|
||||||
|
"Missing --alphabet_config_path flag. Couldn't find an alphabet file\n"
|
||||||
|
"alongside checkpoint, and input datasets are not fully specified\n"
|
||||||
|
"(--train_files, --dev_files, --test_files), so can't generate an alphabet.\n"
|
||||||
|
"Either specify an alphabet file or fully specify the dataset, so one will\n"
|
||||||
|
"be generated automatically."
|
||||||
|
)
|
||||||
|
|
||||||
# Geometric Constants
|
# Geometric Constants
|
||||||
# ===================
|
# ===================
|
||||||
|
|
Loading…
Reference in New Issue