Add Alphabet::InitFromLabels and automatic alphabet generation.

When neither --alphabet_config_path nor --bytes_output_mode is given and no
alphabet.txt is found in the loaded checkpoint directory, the training config
now builds an alphabet from the transcripts of --train_files, --dev_files and
--test_files (all three must be set).

Reviewer notes:
- Line breaks and indentation of this patch were reconstructed from a
  whitespace-collapsed copy; verify context lines against the target tree.
- The post-image hashes on the `index` lines were not recomputed.

diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc
index bd231a45..a2a00dc0 100644
--- a/native_client/alphabet.cc
+++ b/native_client/alphabet.cc
@@ -69,6 +69,26 @@ Alphabet::init(const char *config_file)
   return 0;
 }
 
+// Populate the alphabet from an in-memory list of labels. The position of each
+// entry in `labels` becomes its integer label.
+// NOTE(review): the existing contents of label_to_str_ / str_to_label_ are not
+// cleared here, so this appears intended for a freshly constructed Alphabet.
+void
+Alphabet::InitFromLabels(const std::vector<std::string>& labels)
+{
+  // Stays -2 unless one of the labels is a single space.
+  space_label_ = -2;
+  size_ = labels.size();
+  for (int i = 0; i < size_; ++i) {
+    const std::string& label = labels[i];
+    if (label == " ") {
+      space_label_ = i;
+    }
+    label_to_str_[i] = label;
+    str_to_label_[label] = i;
+  }
+}
+
 std::string
 Alphabet::SerializeText()
 {
diff --git a/native_client/alphabet.h b/native_client/alphabet.h
index 3eb26790..ad75dfc1 100644
--- a/native_client/alphabet.h
+++ b/native_client/alphabet.h
@@ -19,6 +19,10 @@ public:
 
   virtual int init(const char *config_file);
 
+  // Initialize directly from sequence of labels. The index of each label in
+  // the sequence becomes its integer label.
+  void InitFromLabels(const std::vector<std::string>& labels);
+
   // Serialize alphabet into a binary buffer.
   std::string Serialize();
 
diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py
index 93365c80..82cdd308 100644
--- a/native_client/ctcdecode/__init__.py
+++ b/native_client/ctcdecode/__init__.py
@@ -45,13 +45,23 @@ class Scorer(swigwrapper.Scorer):
 class Alphabet(swigwrapper.Alphabet):
     """Convenience wrapper for Alphabet which calls init in the constructor"""
 
-    def __init__(self, config_path):
+    def __init__(self, config_path=None):
         super(Alphabet, self).__init__()
-        err = self.init(config_path.encode("utf-8"))
-        if err != 0:
-            raise ValueError(
-                "Alphabet initialization failed with error code 0x{:X}".format(err)
-            )
+        # A falsy config_path (None or "") leaves the alphabet uninitialized so
+        # that InitFromLabels can be called afterwards.
+        if config_path:
+            err = self.init(config_path.encode("utf-8"))
+            if err != 0:
+                raise ValueError(
+                    "Alphabet initialization failed with error code 0x{:X}".format(err)
+                )
+
+    def InitFromLabels(self, data):
+        """
+        Initialize the alphabet from an iterable of label strings. Each label
+        is encoded to UTF-8 bytes before being passed to the native method.
+        """
+        return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
 
     def CanEncodeSingle(self, input):
         """
diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py
index f954ff0c..52619d35 100755
--- a/training/coqui_stt_training/util/config.py
+++ b/training/coqui_stt_training/util/config.py
@@ -11,11 +11,13 @@ from attrdict import AttrDict
 from coqpit import MISSING, Coqpit, check_argument
 from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
+from tqdm import tqdm
 from xdg import BaseDirectory as xdg
 
 from .augmentations import NormalizeSampleRate, parse_augmentations
 from .gpu import get_available_gpus
 from .helpers import parse_file_size
 from .io import path_exists_remote
+from .sample_collections import samples_from_sources
 
 
 class _ConfigSingleton:
@@ -120,9 +122,12 @@ class _SttConfig(Coqpit):
 
         # If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
         # look for alphabet file alongside loaded checkpoint.
-        checkpoint_alphabet_file = os.path.join(
+        loaded_checkpoint_alphabet_file = os.path.join(
             self.load_checkpoint_dir, "alphabet.txt"
         )
+        saved_checkpoint_alphabet_file = os.path.join(
+            self.save_checkpoint_dir, "alphabet.txt"
+        )
 
         if self.bytes_output_mode and self.alphabet_config_path:
             raise RuntimeError(
@@ -132,15 +137,43 @@ class _SttConfig(Coqpit):
             self.alphabet = UTF8Alphabet()
         elif self.alphabet_config_path:
             self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
-        elif os.path.exists(checkpoint_alphabet_file):
+        elif os.path.exists(loaded_checkpoint_alphabet_file):
             print(
                 "I --alphabet_config_path not specified, but found an alphabet file "
-                f"alongside specified checkpoint ({checkpoint_alphabet_file}).\n"
-                "I Will use this alphabet file for this run."
+                f"alongside specified checkpoint ({loaded_checkpoint_alphabet_file}). "
+                "Will use this alphabet file for this run."
             )
-            self.alphabet = Alphabet(checkpoint_alphabet_file)
+            self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
+        elif self.train_files and self.dev_files and self.test_files:
+            # Generate alphabet automatically from input dataset, but only if
+            # fully specified, to avoid confusion in case a missing set has extra
+            # characters.
+            print(
+                "I --alphabet_config_path not specified, but all input datasets are "
+                "present (--train_files, --dev_files, --test_files). An alphabet "
+                "will be generated automatically from the data and placed alongside "
+                f"the checkpoint ({saved_checkpoint_alphabet_file})."
+            )
+            characters = set()
+            for sample in tqdm(
+                samples_from_sources(
+                    self.train_files + self.dev_files + self.test_files
+                )
+            ):
+                characters |= set(sample.transcript)
+            # Sort so label indices are deterministic across runs.
+            characters = sorted(characters)
+            print(f"I Generated alphabet characters: {characters}.")
+            self.alphabet = Alphabet()
+            self.alphabet.InitFromLabels(characters)
         else:
-            raise RuntimeError("Missing --alphabet_config_path flag, can't continue")
+            raise RuntimeError(
+                "Missing --alphabet_config_path flag. Couldn't find an alphabet file\n"
+                "alongside checkpoint, and input datasets are not fully specified\n"
+                "(--train_files, --dev_files, --test_files), so can't generate an alphabet.\n"
+                "Either specify an alphabet file or fully specify the dataset, so one will\n"
+                "be generated automatically."
+            )
 
         # Geometric Constants
         # ===================