From 87f0a371b1c0c903f2cae54039753196bf70cce4 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 23 Aug 2021 16:02:13 +0200 Subject: [PATCH] Serialize alphabet alongside checkpoint --- native_client/alphabet.cc | 18 ++++++++++++++++++ native_client/alphabet.h | 3 +++ training/coqui_stt_training/train.py | 5 +++++ 3 files changed, 26 insertions(+) diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc index 9abc65a5..bd231a45 100644 --- a/native_client/alphabet.cc +++ b/native_client/alphabet.cc @@ -69,6 +69,24 @@ Alphabet::init(const char *config_file) return 0; } +std::string +Alphabet::SerializeText() +{ + std::stringstream out; + + out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n" + << "# associated with a numeric label.\n" + << "# A line that starts with # is a comment. You can escape it with \\# if you wish\n" + << "# to use '#' as a label.\n"; + + for (int label = 0; label < size_; ++label) { + out << label_to_str_[label] << "\n"; + } + + out << "# The last (non-comment) line needs to end with a newline.\n"; + return out.str(); +} + std::string Alphabet::Serialize() { diff --git a/native_client/alphabet.h b/native_client/alphabet.h index f402cc0d..3eb26790 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -22,6 +22,9 @@ public: // Serialize alphabet into a binary buffer. std::string Serialize(); + // Serialize alphabet into a text representation (ie. config file read by `init`) + std::string SerializeText(); + // Deserialize alphabet from a binary buffer. int Deserialize(const char* buffer, const int buffer_size); diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 6032184c..67079474 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -429,6 +429,11 @@ def train(): with open_remote(flags_file, "w") as fout: json.dump(Config.serialize(), fout, indent=2) + # Serialize alphabet alongside checkpoint + preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt") + with open_remote(preserved_alphabet_file, "wb") as fout: + fout.write(Config.alphabet.SerializeText()) + with tfv1.Session(config=Config.session_config) as session: log_debug("Session opened.")