diff --git a/.gitattributes b/.gitattributes index 79831978..aa841af6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ data/lm/kenlm.scorer filter=lfs diff=lfs merge=lfs -text -.github/actions/check_artifact_exists/dist/index.js binary \ No newline at end of file +.github/actions/check_artifact_exists/dist/index.js binary diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 94a37cae..ceed5544 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,3 @@ repos: - id: isort name: isort (pyi) types: [pyi] - - repo: https://github.com/pycqa/pylint - rev: v2.8.2 - hooks: - - id: pylint diff --git a/BIBLIOGRAPHY.md b/BIBLIOGRAPHY.md index f675f38b..76f4885f 100644 --- a/BIBLIOGRAPHY.md +++ b/BIBLIOGRAPHY.md @@ -3,16 +3,16 @@ This file contains a list of papers in chronological order that have been publis To appear ========== -* Raghuveer Peri, Haoqi Li, Krishna Somandepalli, Arindam Jati, Shrikanth Narayanan (2020) "An empirical analysis of information encoded in disentangled neural speaker representations". +* Raghuveer Peri, Haoqi Li, Krishna Somandepalli, Arindam Jati, Shrikanth Narayanan (2020) "An empirical analysis of information encoded in disentangled neural speaker representations". * Rosana Ardila, Megan Branson, Kelly Davis, Michael Henretty, Michael Kohler, Josh Meyer, Reuben Morais, Lindsay Saunders, Francis M. Tyers, and Gregor Weber (2020) "Common Voice: A Massively-Multilingual Speech Corpus". -Published +Published ========== 2020 ---------- -* Nils Hjortnaes, Niko Partanen, Michael Rießler and Francis M. Tyers (2020) +* Nils Hjortnaes, Niko Partanen, Michael Rießler and Francis M. Tyers (2020) "Towards a Speech Recognizer for Komi, an Endangered and Low-Resource Uralic Language". *Proceedings of the 6th International Workshop on Computational Linguistics of Uralic Languages*. ``` @@ -72,5 +72,5 @@ Published booktitle = {2018 IEEE/ACM Machine Learning in HPC Environments (MLHPC)}, doi = {https://doi.org/10.1109/MLHPC.2018.8638637} year = 2018 -} +} ``` diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index bdb48cd1..4465920c 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -118,11 +118,11 @@ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. -Community Impact Guidelines were inspired by +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. For answers to common questions about this code of conduct, see the FAQ at -[https://www.contributor-covenant.org/faq][FAQ]. Translations are available +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. [homepage]: https://www.contributor-covenant.org diff --git a/CODE_OWNERS.rst b/CODE_OWNERS.rst index 92150211..e7253b2f 100644 --- a/CODE_OWNERS.rst +++ b/CODE_OWNERS.rst @@ -112,5 +112,5 @@ Documentation .. Third party bindings -------------------- - + Hosted externally and owned by the individual authors. See the `list of third-party bindings `_ for more info. diff --git a/Dockerfile.build b/Dockerfile.build index 5f175113..e7d2e6b5 100644 --- a/Dockerfile.build +++ b/Dockerfile.build @@ -1,6 +1,6 @@ # Please refer to the USING documentation, "Dockerfile for building from source" -# Need devel version cause we need /usr/include/cudnn.h +# Need devel version cause we need /usr/include/cudnn.h FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 ARG STT_REPO=https://github.com/coqui-ai/STT.git diff --git a/README.rst b/README.rst index d2dce20c..29a75e12 100644 --- a/README.rst +++ b/README.rst @@ -9,14 +9,14 @@ .. |covenant-img| image:: https://img.shields.io/badge/Contributor%20Covenant-2.0-4baaaa.svg :target: CODE_OF_CONDUCT.md :alt: Contributor Covenant - + .. |gitter-img| image:: https://badges.gitter.im/coqui-ai/STT.svg :target: https://gitter.im/coqui-ai/STT?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge :alt: Gitter Room - + .. |doi| image:: https://zenodo.org/badge/344354127.svg :target: https://zenodo.org/badge/latestdoi/344354127 - + |doc-img| |covenant-img| |gitter-img| |doi| `👉 Subscribe to 🐸Coqui's Newsletter `_ @@ -31,16 +31,16 @@ * Streaming inference. * Multiple possible transcripts, each with an associated confidence score. * Real-time inference. -* Small-footprint acoustic model. -* Bindings for various programming languages. +* Small-footprint acoustic model. +* Bindings for various programming languages. Where to Ask Questions ---------------------- .. list-table:: - :widths: 25 25 + :widths: 25 25 :header-rows: 1 - + * - Type - Link * - 🚨 **Bug Reports** @@ -51,14 +51,14 @@ Where to Ask Questions - `Github Discussions `_ * - 💬 **General Discussion** - `Github Discussions `_ or `Gitter Room `_ - - + + Links & Resources ----------------- -.. list-table:: - :widths: 25 25 +.. list-table:: + :widths: 25 25 :header-rows: 1 - + * - Type - Link * - 📰 **Documentation** @@ -67,4 +67,3 @@ Links & Resources - `see the latest release on GitHub `_ * - 🤝 **Contribution Guidelines** - `CONTRIBUTING.rst `_ - diff --git a/bazel.patch b/bazel.patch index 1b2addd2..4a166036 100644 --- a/bazel.patch +++ b/bazel.patch @@ -9,23 +9,23 @@ index c7aa4cb63..e084bc27c 100644 +import java.io.PrintWriter; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; - + @@ -73,6 +74,8 @@ public final class FileWriteAction extends AbstractFileWriteAction { */ private final CharSequence fileContents; - + + private final Artifact output; + /** Minimum length (in chars) for content to be eligible for compression. */ private static final int COMPRESS_CHARS_THRESHOLD = 256; - + @@ -90,6 +93,7 @@ public final class FileWriteAction extends AbstractFileWriteAction { fileContents = new CompressedString((String) fileContents); } this.fileContents = fileContents; + this.output = output; } - + /** @@ -230,11 +234,32 @@ public final class FileWriteAction extends AbstractFileWriteAction { */ @@ -59,7 +59,7 @@ index c7aa4cb63..e084bc27c 100644 + computeKeyDebugWriter.close(); + return rv; } - + /** diff --git a/src/main/java/com/google/devtools/build/lib/analysis/actions/SpawnAction.java b/src/main/java/com/google/devtools/build/lib/analysis/actions/SpawnAction.java index 580788160..26883eb92 100644 @@ -74,9 +74,9 @@ index 580788160..26883eb92 100644 import java.util.Collections; import java.util.LinkedHashMap; @@ -91,6 +92,9 @@ public class SpawnAction extends AbstractAction implements ExecutionInfoSpecifie - + private final CommandLine argv; - + + private final Iterable inputs; + private final Iterable outputs; + @@ -91,10 +91,10 @@ index 580788160..26883eb92 100644 + this.inputs = inputs; + this.outputs = outputs; } - + @Override @@ -312,23 +319,89 @@ public class SpawnAction extends AbstractAction implements ExecutionInfoSpecifie - + @Override protected String computeKey() { + boolean genruleSetup = String.valueOf(Iterables.get(inputs, 0).getExecPath()).contains("genrule/genrule-setup.sh"); @@ -182,14 +182,14 @@ index 580788160..26883eb92 100644 + } + return rv; } - + @Override diff --git a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java index 3559fffde..3ba39617c 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java +++ b/src/main/java/com/google/devtools/build/lib/rules/cpp/CppCompileAction.java @@ -1111,10 +1111,30 @@ public class CppCompileAction extends AbstractAction - + @Override public String computeKey() { + // ".ckd" Compute Key Debug @@ -216,7 +216,7 @@ index 3559fffde..3ba39617c 100644 + for (Map.Entry entry : executionInfo.entrySet()) { + computeKeyDebugWriter.println("EXECINFO: " + entry.getKey() + "=" + entry.getValue()); + } - + // For the argv part of the cache key, ignore all compiler flags that explicitly denote module // file (.pcm) inputs. Depending on input discovery, some of the unused ones are removed from @@ -1124,6 +1144,9 @@ public class CppCompileAction extends AbstractAction @@ -226,7 +226,7 @@ index 3559fffde..3ba39617c 100644 + for (String input : compileCommandLine.getArgv(getInternalOutputFile(), null)) { + computeKeyDebugWriter.println("COMMAND: " + input); + } - + /* * getArgv() above captures all changes which affect the compilation @@ -1133,19 +1156,31 @@ public class CppCompileAction extends AbstractAction @@ -260,5 +260,5 @@ index 3559fffde..3ba39617c 100644 + computeKeyDebugWriter.close(); + return rv; } - + @Override diff --git a/bin/compare_samples.py b/bin/compare_samples.py index 3bef72ca..cb07cbd8 100755 --- a/bin/compare_samples.py +++ b/bin/compare_samples.py @@ -2,10 +2,10 @@ """ Tool for comparing two wav samples """ -import sys import argparse -import numpy as np +import sys +import numpy as np from coqui_stt_training.util.audio import AUDIO_TYPE_NP, mean_dbfs from coqui_stt_training.util.sample_collections import load_sample @@ -19,19 +19,29 @@ def compare_samples(): sample1 = load_sample(CLI_ARGS.sample1).unpack() sample2 = load_sample(CLI_ARGS.sample2).unpack() if sample1.audio_format != sample2.audio_format: - fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format)) + fail( + "Samples differ on: audio-format ({} and {})".format( + sample1.audio_format, sample2.audio_format + ) + ) if abs(sample1.duration - sample2.duration) > 0.001: - fail('Samples differ on: duration ({} and {})'.format(sample1.duration, sample2.duration)) + fail( + "Samples differ on: duration ({} and {})".format( + sample1.duration, sample2.duration + ) + ) sample1.change_audio_type(AUDIO_TYPE_NP) sample2.change_audio_type(AUDIO_TYPE_NP) samples = [sample1, sample2] largest = np.argmax([sample1.audio.shape[0], sample2.audio.shape[0]]) smallest = (largest + 1) % 2 - samples[largest].audio = samples[largest].audio[:len(samples[smallest].audio)] + samples[largest].audio = samples[largest].audio[: len(samples[smallest].audio)] audio_diff = samples[largest].audio - samples[smallest].audio diff_dbfs = mean_dbfs(audio_diff) - differ_msg = 'Samples differ on: sample data ({:0.2f} dB difference) '.format(diff_dbfs) - equal_msg = 'Samples are considered equal ({:0.2f} dB difference)'.format(diff_dbfs) + differ_msg = "Samples differ on: sample data ({:0.2f} dB difference) ".format( + diff_dbfs + ) + equal_msg = "Samples are considered equal ({:0.2f} dB difference)".format(diff_dbfs) if CLI_ARGS.if_differ: if diff_dbfs <= CLI_ARGS.threshold: fail(equal_msg) @@ -50,13 +60,17 @@ def handle_args(): ) parser.add_argument("sample1", help="Filename of sample 1 to compare") parser.add_argument("sample2", help="Filename of sample 2 to compare") - parser.add_argument("--threshold", type=float, default=-60.0, - help="dB of sample deltas above which they are considered different") + parser.add_argument( + "--threshold", + type=float, + default=-60.0, + help="dB of sample deltas above which they are considered different", + ) parser.add_argument( "--if-differ", action="store_true", help="If to succeed and return status code 0 on different signals and fail on equal ones (inverse check)." - "This will still fail on different formats or durations.", + "This will still fail on different formats or durations.", ) parser.add_argument( "--no-success-output", diff --git a/bin/data_set_tool.py b/bin/data_set_tool.py index 521dda21..bdff618d 100755 --- a/bin/data_set_tool.py +++ b/bin/data_set_tool.py @@ -1,19 +1,24 @@ #!/usr/bin/env python -''' +""" Tool for building a combined SDB or CSV sample-set from other sets Use 'python3 data_set_tool.py -h' for help -''' -import sys +""" import argparse -import progressbar +import sys from pathlib import Path +import progressbar from coqui_stt_training.util.audio import ( - AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS, + AUDIO_TYPE_PCM, AUDIO_TYPE_WAV, change_audio_types, ) +from coqui_stt_training.util.augmentations import ( + SampleAugmentation, + apply_sample_augmentations, + parse_augmentations, +) from coqui_stt_training.util.downloader import SIMPLE_BAR from coqui_stt_training.util.sample_collections import ( CSVWriter, @@ -21,101 +26,110 @@ from coqui_stt_training.util.sample_collections import ( TarWriter, samples_from_sources, ) -from coqui_stt_training.util.augmentations import ( - parse_augmentations, - apply_sample_augmentations, - SampleAugmentation -) -AUDIO_TYPE_LOOKUP = {'wav': AUDIO_TYPE_WAV, 'opus': AUDIO_TYPE_OPUS} +AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS} def build_data_set(): audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] augmentations = parse_augmentations(CLI_ARGS.augment) if any(not isinstance(a, SampleAugmentation) for a in augmentations): - print('Warning: Some of the specified augmentations will not get applied, as this tool only supports ' - 'overlay, codec, reverb, resample and volume.') + print( + "Warning: Some of the specified augmentations will not get applied, as this tool only supports " + "overlay, codec, reverb, resample and volume." + ) extension = Path(CLI_ARGS.target).suffix.lower() labeled = not CLI_ARGS.unlabeled - if extension == '.csv': - writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled) - elif extension == '.sdb': - writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled) - elif extension == '.tar': - writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include) - elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'): - writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include) + if extension == ".csv": + writer = CSVWriter( + CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled + ) + elif extension == ".sdb": + writer = DirectSDBWriter( + CLI_ARGS.target, audio_type=audio_type, labeled=labeled + ) + elif extension == ".tar": + writer = TarWriter( + CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include + ) + elif extension == ".tgz" or CLI_ARGS.target.lower().endswith(".tar.gz"): + writer = TarWriter( + CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include + ) else: - print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz') + print( + "Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz" + ) sys.exit(1) with writer: samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) num_samples = len(samples) if augmentations: - samples = apply_sample_augmentations(samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations) + samples = apply_sample_augmentations( + samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations + ) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for sample in bar(change_audio_types( + for sample in bar( + change_audio_types( samples, audio_type=audio_type, bitrate=CLI_ARGS.bitrate, - processes=CLI_ARGS.workers)): + processes=CLI_ARGS.workers, + ) + ): writer.add(sample) def handle_args(): parser = argparse.ArgumentParser( - description='Tool for building a combined SDB or CSV sample-set from other sets' + description="Tool for building a combined SDB or CSV sample-set from other sets" ) parser.add_argument( - 'sources', - nargs='+', - help='Source CSV and/or SDB files - ' - 'Note: For getting a correctly ordered target set, source SDBs have to have their samples ' - 'already ordered from shortest to longest.', + "sources", + nargs="+", + help="Source CSV and/or SDB files - " + "Note: For getting a correctly ordered target set, source SDBs have to have their samples " + "already ordered from shortest to longest.", ) + parser.add_argument("target", help="SDB, CSV or TAR(.gz) file to create") parser.add_argument( - 'target', - help='SDB, CSV or TAR(.gz) file to create' - ) - parser.add_argument( - '--audio-type', - default='opus', + "--audio-type", + default="opus", choices=AUDIO_TYPE_LOOKUP.keys(), - help='Audio representation inside target SDB', + help="Audio representation inside target SDB", ) parser.add_argument( - '--bitrate', + "--bitrate", type=int, - help='Bitrate for lossy compressed SDB samples like in case of --audio-type opus', + help="Bitrate for lossy compressed SDB samples like in case of --audio-type opus", ) parser.add_argument( - '--workers', type=int, default=None, help='Number of encoding SDB workers' + "--workers", type=int, default=None, help="Number of encoding SDB workers" ) parser.add_argument( - '--unlabeled', - action='store_true', - help='If to build an data-set with unlabeled (audio only) samples - ' - 'typically used for building noise augmentation corpora', + "--unlabeled", + action="store_true", + help="If to build an data-set with unlabeled (audio only) samples - " + "typically used for building noise augmentation corpora", ) parser.add_argument( - '--absolute-paths', - action='store_true', - help='If to reference samples by their absolute paths when writing CSV files', + "--absolute-paths", + action="store_true", + help="If to reference samples by their absolute paths when writing CSV files", ) parser.add_argument( - '--augment', - action='append', - help='Add an augmentation operation', + "--augment", + action="append", + help="Add an augmentation operation", ) parser.add_argument( - '--include', - action='append', - help='Adds a file to the root directory of .tar(.gz) targets', + "--include", + action="append", + help="Adds a file to the root directory of .tar(.gz) targets", ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": CLI_ARGS = handle_args() build_data_set() diff --git a/bin/graphdef_binary_to_text.py b/bin/graphdef_binary_to_text.py index 032d3836..83f32fdb 100755 --- a/bin/graphdef_binary_to_text.py +++ b/bin/graphdef_binary_to_text.py @@ -3,9 +3,10 @@ import sys -import tensorflow.compat.v1 as tfv1 from google.protobuf import text_format +import tensorflow.compat.v1 as tfv1 + def main(): # Load and export as string diff --git a/bin/import_aidatatang.py b/bin/import_aidatatang.py index 8eac7de6..f46f2f35 100755 --- a/bin/import_aidatatang.py +++ b/bin/import_aidatatang.py @@ -4,7 +4,6 @@ import os import tarfile import pandas - from coqui_stt_training.util.importers import get_importers_parser COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] diff --git a/bin/import_aishell.py b/bin/import_aishell.py index 3ca71f02..32ac7438 100755 --- a/bin/import_aishell.py +++ b/bin/import_aishell.py @@ -4,7 +4,6 @@ import os import tarfile import pandas - from coqui_stt_training.util.importers import get_importers_parser COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"] diff --git a/bin/import_ccpmf.py b/bin/import_ccpmf.py index 80c68d33..b3e1faad 100755 --- a/bin/import_ccpmf.py +++ b/bin/import_ccpmf.py @@ -5,21 +5,21 @@ Ministère de l'Économie, des Finances et de la Relance """ import csv -import sys +import decimal +import hashlib +import math import os -import progressbar +import re import subprocess +import sys +import unicodedata +import xml.etree.ElementTree as ET import zipfile from glob import glob from multiprocessing import Pool -import hashlib -import decimal -import math -import unicodedata -import re +import progressbar import sox -import xml.etree.ElementTree as ET try: from num2words import num2words @@ -27,19 +27,19 @@ except ImportError as ex: print("pip install num2words") sys.exit(1) -import requests import json +import requests +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.helpers import secs_to_hours from coqui_stt_training.util.importers import ( get_counter, - get_importers_parser, get_imported_samples, + get_importers_parser, get_validate_label, print_import_report, ) -from coqui_stt_ctcdecoder import Alphabet FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] SAMPLE_RATE = 16000 @@ -50,58 +50,187 @@ MIN_SECS = 0.85 DATASET_RELEASE_CSV = "https://data.economie.gouv.fr/explore/dataset/transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020/download/?format=csv&timezone=Europe/Berlin&lang=fr&use_labels_for_header=true&csv_separator=%3B" DATASET_RELEASE_SHA = [ - ("863d39a06a388c6491c6ff2f6450b151f38f1b57", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.001"), - ("2f3a0305aa04c61220bb00b5a4e553e45dbf12e1", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.002"), - ("5e55e9f1f844097349188ac875947e5a3d7fe9f1", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.003"), - ("8bf54842cf07948ca5915e27a8bd5fa5139c06ae", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.004"), - ("c8963504aadc015ac48f9af80058a0bb3440b94f", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.005"), - ("d95e225e908621d83ce4e9795fd108d9d310e244", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.006"), - ("de6ed9c2b0ee80ca879aae8ba7923cc93217d811", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.007"), - ("234283c47dacfcd4450d836c52c25f3e807fc5f2", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.008"), - ("4e6b67a688639bb72f8cd81782eaba604a8d32a6", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.009"), - ("4165a51389777c8af8e6253d87bdacb877e8b3b0", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.010"), - ("34322e7009780d97ef5bd02bf2f2c7a31f00baff", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.011"), - ("48c5be3b2ca9d6108d525da6a03e91d93a95dbac", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.012"), - ("87573172f506a189c2ebc633856fe11a2e9cd213", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.013"), - ("6ab2c9e508e9278d5129f023e018725c4a7c69e8", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.014"), - ("4f84df831ef46dce5d3ab3e21817687a2d8c12d0", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.015"), - ("e69bfb079885c299cb81080ef88b1b8b57158aa6", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.016"), - ("5f764ba788ee273981cf211b242c29b49ca22c5e", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.017"), - ("b6aa81a959525363223494830c1e7307d4c4bae6", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.018"), - ("91ddcf43c7bf113a6f2528b857c7ec22a50a148a", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.019"), - ("fa1b29273dd77b9a7494983a2f9ae52654b931d7", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.020"), - ("1113aef4f5e2be2f7fbf2d54b6c710c1c0e7135f", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.021"), - ("ce6420d5d0b6b5135ba559f83e1a82d4d615c470", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.022"), - ("d0976ed292ac24fcf1590d1ea195077c74b05471", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.023"), - ("ec746cd6af066f62d9bf8d3b2f89174783ff4e3c", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.024"), - ("570d9e1e84178e32fd867171d4b3aaecda1fd4fb", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.025"), - ("c29ccc7467a75b2cae3d7f2e9fbbb2ab276cb8ac", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.026"), - ("08406a51146d88e208704ce058c060a1e44efa50", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.027"), - ("199aedad733a78ea1e7d47def9c71c6fd5795e02", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.028"), - ("db856a068f92fb4f01f410bba42c7271de0f231a", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.029"), - ("e3c0135f16c6c9d25a09dcb4f99a685438a84740", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.030"), - ("e51b8bb9c0ae4339f98b4f21e6d29b825109f0ac", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.031"), - ("be5e80cbc49b59b31ae33c30576ef0e1a162d84e", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.032"), - ("501df58e3ff55fcfd75b93dab57566dc536948b8", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.033"), - ("1a114875811a8cdcb8d85a9f6dbee78be3e05131", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.034"), - ("465d824e7ee46448369182c0c28646d155a2249b", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.035"), - ("37f341b1b266d143eb73138c31cfff3201b9d619", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.036"), - ("9e7d8255987a8a77a90e0d4b55c8fd38b9fb5694", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.037"), - ("54886755630cb080a53098cb1b6c951c6714a143", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.038"), - ("4b7cbb0154697be795034f7a49712e882a97197a", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.039"), - ("c8e1e565a0e7a1f6ff1dbfcefe677aa74a41d2f2", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.040"), + ( + "863d39a06a388c6491c6ff2f6450b151f38f1b57", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.001", + ), + ( + "2f3a0305aa04c61220bb00b5a4e553e45dbf12e1", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.002", + ), + ( + "5e55e9f1f844097349188ac875947e5a3d7fe9f1", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.003", + ), + ( + "8bf54842cf07948ca5915e27a8bd5fa5139c06ae", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.004", + ), + ( + "c8963504aadc015ac48f9af80058a0bb3440b94f", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.005", + ), + ( + "d95e225e908621d83ce4e9795fd108d9d310e244", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.006", + ), + ( + "de6ed9c2b0ee80ca879aae8ba7923cc93217d811", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.007", + ), + ( + "234283c47dacfcd4450d836c52c25f3e807fc5f2", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.008", + ), + ( + "4e6b67a688639bb72f8cd81782eaba604a8d32a6", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.009", + ), + ( + "4165a51389777c8af8e6253d87bdacb877e8b3b0", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.010", + ), + ( + "34322e7009780d97ef5bd02bf2f2c7a31f00baff", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.011", + ), + ( + "48c5be3b2ca9d6108d525da6a03e91d93a95dbac", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.012", + ), + ( + "87573172f506a189c2ebc633856fe11a2e9cd213", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.013", + ), + ( + "6ab2c9e508e9278d5129f023e018725c4a7c69e8", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.014", + ), + ( + "4f84df831ef46dce5d3ab3e21817687a2d8c12d0", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.015", + ), + ( + "e69bfb079885c299cb81080ef88b1b8b57158aa6", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.016", + ), + ( + "5f764ba788ee273981cf211b242c29b49ca22c5e", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.017", + ), + ( + "b6aa81a959525363223494830c1e7307d4c4bae6", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.018", + ), + ( + "91ddcf43c7bf113a6f2528b857c7ec22a50a148a", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.019", + ), + ( + "fa1b29273dd77b9a7494983a2f9ae52654b931d7", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.020", + ), + ( + "1113aef4f5e2be2f7fbf2d54b6c710c1c0e7135f", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.021", + ), + ( + "ce6420d5d0b6b5135ba559f83e1a82d4d615c470", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.022", + ), + ( + "d0976ed292ac24fcf1590d1ea195077c74b05471", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.023", + ), + ( + "ec746cd6af066f62d9bf8d3b2f89174783ff4e3c", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.024", + ), + ( + "570d9e1e84178e32fd867171d4b3aaecda1fd4fb", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.025", + ), + ( + "c29ccc7467a75b2cae3d7f2e9fbbb2ab276cb8ac", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.026", + ), + ( + "08406a51146d88e208704ce058c060a1e44efa50", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.027", + ), + ( + "199aedad733a78ea1e7d47def9c71c6fd5795e02", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.028", + ), + ( + "db856a068f92fb4f01f410bba42c7271de0f231a", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.029", + ), + ( + "e3c0135f16c6c9d25a09dcb4f99a685438a84740", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.030", + ), + ( + "e51b8bb9c0ae4339f98b4f21e6d29b825109f0ac", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.031", + ), + ( + "be5e80cbc49b59b31ae33c30576ef0e1a162d84e", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.032", + ), + ( + "501df58e3ff55fcfd75b93dab57566dc536948b8", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.033", + ), + ( + "1a114875811a8cdcb8d85a9f6dbee78be3e05131", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.034", + ), + ( + "465d824e7ee46448369182c0c28646d155a2249b", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.035", + ), + ( + "37f341b1b266d143eb73138c31cfff3201b9d619", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.036", + ), + ( + "9e7d8255987a8a77a90e0d4b55c8fd38b9fb5694", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.037", + ), + ( + "54886755630cb080a53098cb1b6c951c6714a143", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.038", + ), + ( + "4b7cbb0154697be795034f7a49712e882a97197a", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.039", + ), + ( + "c8e1e565a0e7a1f6ff1dbfcefe677aa74a41d2f2", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.040", + ), ] + def _download_and_preprocess_data(csv_url, target_dir): - dataset_sources = os.path.join(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", "data.txt") + dataset_sources = os.path.join( + target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", "data.txt" + ) if os.path.exists(dataset_sources): return dataset_sources # Making path absolute target_dir = os.path.abspath(target_dir) - csv_ref = requests.get(csv_url).text.split('\r\n')[1:-1] + csv_ref = requests.get(csv_url).text.split("\r\n")[1:-1] for part in csv_ref: - part_filename = requests.head(part).headers.get("Content-Disposition").split(" ")[1].split("=")[1].replace('"', "") + part_filename = ( + requests.head(part) + .headers.get("Content-Disposition") + .split(" ")[1] + .split("=")[1] + .replace('"', "") + ) if not os.path.exists(os.path.join(target_dir, part_filename)): part_path = maybe_download(part_filename, target_dir, part) @@ -126,10 +255,18 @@ def _download_and_preprocess_data(csv_url, target_dir): assert csum == sha1 # Conditionally extract data - _maybe_extract(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip", "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020.zip") + _maybe_extract( + target_dir, + "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", + "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip", + "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020.zip", + ) # Produce source text for extraction / conversion - return _maybe_create_sources(os.path.join(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020")) + return _maybe_create_sources( + os.path.join(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020") + ) + def _maybe_extract(target_dir, extracted_data, archive, final): # If target_dir/extracted_data does not exist, extract archive in target_dir @@ -147,7 +284,10 @@ def _maybe_extract(target_dir, extracted_data, archive, final): subprocess.check_call(cmdline, shell=True, cwd=target_dir) assert os.path.exists(archive_path) - print('No directory "%s" - extracting archive %s ...' % (extracted_path, archive_path)) + print( + 'No directory "%s" - extracting archive %s ...' + % (extracted_path, archive_path) + ) with zipfile.ZipFile(archive_path) as zip_f: zip_f.extractall(extracted_path) @@ -156,6 +296,7 @@ def _maybe_extract(target_dir, extracted_data, archive, final): else: print('Found directory "%s" - not extracting it from archive.' % extracted_path) + def _maybe_create_sources(dir): dataset_sources = os.path.join(dir, "data.txt") MP3 = glob(os.path.join(dir, "**", "*.mp3")) @@ -168,8 +309,8 @@ def _maybe_create_sources(dir): for f_xml in XML: b_mp3 = os.path.splitext(os.path.basename(f_mp3))[0] b_xml = os.path.splitext(os.path.basename(f_xml))[0] - a_mp3 = b_mp3.split('_') - a_xml = b_xml.split('_') + a_mp3 = b_mp3.split("_") + a_xml = b_xml.split("_") score = 0 date_mp3 = a_mp3[0] date_xml = a_xml[0] @@ -178,7 +319,7 @@ def _maybe_create_sources(dir): continue for i in range(min(len(a_mp3), len(a_xml))): - if (a_mp3[i] == a_xml[i]): + if a_mp3[i] == a_xml[i]: score += 1 if score >= 1: @@ -187,7 +328,7 @@ def _maybe_create_sources(dir): # sort by score MP3_XML_Scores.sort(key=lambda x: x[2], reverse=True) for s_mp3, s_xml, score in MP3_XML_Scores: - #print(s_mp3, s_xml, score) + # print(s_mp3, s_xml, score) if score not in MP3_XML_Fin: MP3_XML_Fin[score] = {} @@ -208,13 +349,14 @@ def _maybe_create_sources(dir): if os.path.getsize(mp3) > 0 and os.path.getsize(xml) > 0: mp3 = os.path.relpath(mp3, dir) xml = os.path.relpath(xml, dir) - ds.write('{},{},{:0.2e}\n'.format(xml, mp3, 2.5e-4)) + ds.write("{},{},{:0.2e}\n".format(xml, mp3, 2.5e-4)) else: print("Empty file {} or {}".format(mp3, xml), file=sys.stderr) print("Missing XML pairs:", MP3, file=sys.stderr) return dataset_sources + def maybe_normalize_for_digits(label): # first, try to identify numbers like "50 000", "260 000" if " " in label: @@ -234,30 +376,44 @@ def maybe_normalize_for_digits(label): date_or_time = re.compile(r"(\d{1,2}):(\d{2}):?(\d{2})?") maybe_date_or_time = date_or_time.findall(s) if len(maybe_date_or_time) > 0: - maybe_hours = maybe_date_or_time[0][0] + maybe_hours = maybe_date_or_time[0][0] maybe_minutes = maybe_date_or_time[0][1] maybe_seconds = maybe_date_or_time[0][2] if len(maybe_seconds) > 0: - label = label.replace("{}:{}:{}".format(maybe_hours, maybe_minutes, maybe_seconds), "{} heures {} minutes et {} secondes".format(maybe_hours, maybe_minutes, maybe_seconds)) + label = label.replace( + "{}:{}:{}".format( + maybe_hours, maybe_minutes, maybe_seconds + ), + "{} heures {} minutes et {} secondes".format( + maybe_hours, maybe_minutes, maybe_seconds + ), + ) else: - label = label.replace("{}:{}".format(maybe_hours, maybe_minutes), "{} heures et {} minutes".format(maybe_hours, maybe_minutes)) + label = label.replace( + "{}:{}".format(maybe_hours, maybe_minutes), + "{} heures et {} minutes".format( + maybe_hours, maybe_minutes + ), + ) new_label = [] # pylint: disable=too-many-nested-blocks for s in label.split(" "): if any(i.isdigit() for i in s): - s = s.replace(",", ".") # num2words requires "." for floats - s = s.replace("\"", "") # clean some data, num2words would choke on 1959" + s = s.replace(",", ".") # num2words requires "." for floats + s = s.replace('"', "") # clean some data, num2words would choke on 1959" last_c = s[-1] - if not last_c.isdigit(): # num2words will choke on "0.6.", "24 ?" + if not last_c.isdigit(): # num2words will choke on "0.6.", "24 ?" s = s[:-1] - if any(i.isalpha() for i in s): # So we have any(isdigit()) **and** any(sialpha), like "3D" + if any( + i.isalpha() for i in s + ): # So we have any(isdigit()) **and** any(sialpha), like "3D" ns = [] for c in s: nc = c - if c.isdigit(): # convert "3" to "trois-" + if c.isdigit(): # convert "3" to "trois-" try: nc = num2words(c, lang="fr") + "-" except decimal.InvalidOperation as ex: @@ -274,22 +430,36 @@ def maybe_normalize_for_digits(label): new_label.append(s) return " ".join(new_label) + def maybe_normalize_for_specials_chars(label): label = label.replace("%", "pourcents") - label = label.replace("/", ", ") # clean intervals like 2019/2022 to "2019 2022" - label = label.replace("-", ", ") # clean intervals like 70-80 to "70 80" - label = label.replace("+", " plus ") # clean + and make it speakable - label = label.replace("€", " euros ") # clean euro symbol and make it speakable - label = label.replace("., ", ", ") # clean some strange "4.0., " (20181017_Innovation.xml) - label = label.replace("°", " degré ") # clean some strange "°5" (20181210_EtatsGeneraux-1000_fre_750_und.xml) - label = label.replace("...", ".") # remove ellipsis - label = label.replace("..", ".") # remove broken ellipsis - label = label.replace("m²", "mètre-carrés") # 20150616_Defi_Climat_3_wmv_0_fre_minefi.xml - label = label.replace("[end]", "") # broken tag in 20150123_Entretiens_Tresor_PGM_wmv_0_fre_minefi.xml - label = label.replace(u'\xB8c', " ç") # strange cedilla in 20150417_Printemps_Economie_2_wmv_0_fre_minefi.xml - label = label.replace("C0²", "CO 2") # 20121016_Syteme_sante_copie_wmv_0_fre_minefi.xml + label = label.replace("/", ", ") # clean intervals like 2019/2022 to "2019 2022" + label = label.replace("-", ", ") # clean intervals like 70-80 to "70 80" + label = label.replace("+", " plus ") # clean + and make it speakable + label = label.replace("€", " euros ") # clean euro symbol and make it speakable + label = label.replace( + "., ", ", " + ) # clean some strange "4.0., " (20181017_Innovation.xml) + label = label.replace( + "°", " degré " + ) # clean some strange "°5" (20181210_EtatsGeneraux-1000_fre_750_und.xml) + label = label.replace("...", ".") # remove ellipsis + label = label.replace("..", ".") # remove broken ellipsis + label = label.replace( + "m²", "mètre-carrés" + ) # 20150616_Defi_Climat_3_wmv_0_fre_minefi.xml + label = label.replace( + "[end]", "" + ) # broken tag in 20150123_Entretiens_Tresor_PGM_wmv_0_fre_minefi.xml + label = label.replace( + u"\xB8c", " ç" + ) # strange cedilla in 20150417_Printemps_Economie_2_wmv_0_fre_minefi.xml + label = label.replace( + "C0²", "CO 2" + ) # 20121016_Syteme_sante_copie_wmv_0_fre_minefi.xml return label + def maybe_normalize_for_anglicisms(label): label = label.replace("B2B", "B to B") label = label.replace("B2C", "B to C") @@ -297,12 +467,14 @@ def maybe_normalize_for_anglicisms(label): label = label.replace("@", "at ") return label + def maybe_normalize(label): label = maybe_normalize_for_specials_chars(label) label = maybe_normalize_for_anglicisms(label) label = maybe_normalize_for_digits(label) return label + def one_sample(sample): file_size = -1 frames = 0 @@ -316,14 +488,33 @@ def one_sample(sample): label = label_filter_fun(sample[5]) sample_id = sample[6] - _wav_filename = os.path.basename(audio_source.replace(".wav", "_{:06}.wav".format(sample_id))) + _wav_filename = os.path.basename( + audio_source.replace(".wav", "_{:06}.wav".format(sample_id)) + ) wav_fullname = os.path.join(target_dir, dataset_basename, _wav_filename) if not os.path.exists(wav_fullname): - subprocess.check_output(["ffmpeg", "-i", audio_source, "-ss", str(start_time), "-t", str(duration), "-c", "copy", wav_fullname], stdin=subprocess.DEVNULL, stderr=subprocess.STDOUT) + subprocess.check_output( + [ + "ffmpeg", + "-i", + audio_source, + "-ss", + str(start_time), + "-t", + str(duration), + "-c", + "copy", + wav_fullname, + ], + stdin=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) file_size = os.path.getsize(wav_fullname) - frames = int(subprocess.check_output(["soxi", "-s", wav_fullname], stderr=subprocess.STDOUT)) + frames = int( + subprocess.check_output(["soxi", "-s", wav_fullname], stderr=subprocess.STDOUT) + ) _counter = get_counter() _rows = [] @@ -334,13 +525,13 @@ def one_sample(sample): elif label is None: # Excluding samples that failed on label validation _counter["invalid_label"] += 1 - elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): + elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)): # Excluding samples that are too short to fit the transcript _counter["too_short"] += 1 - elif frames/SAMPLE_RATE < MIN_SECS: + elif frames / SAMPLE_RATE < MIN_SECS: # Excluding samples that are too short _counter["too_short"] += 1 - elif frames/SAMPLE_RATE > MAX_SECS: + elif frames / SAMPLE_RATE > MAX_SECS: # Excluding very long samples to keep a reasonable batch-size _counter["too_long"] += 1 else: @@ -352,56 +543,71 @@ def one_sample(sample): return (_counter, _rows) + def _maybe_import_data(xml_file, audio_source, target_dir, rel_tol=1e-1): dataset_basename = os.path.splitext(os.path.split(xml_file)[1])[0] wav_root = os.path.join(target_dir, dataset_basename) if not os.path.exists(wav_root): os.makedirs(wav_root) - source_frames = int(subprocess.check_output(["soxi", "-s", audio_source], stderr=subprocess.STDOUT)) + source_frames = int( + subprocess.check_output(["soxi", "-s", audio_source], stderr=subprocess.STDOUT) + ) print("Source audio length: %s" % secs_to_hours(source_frames / SAMPLE_RATE)) # Get audiofile path and transcript for each sentence in tsv samples = [] tree = ET.parse(xml_file) root = tree.getroot() - seq_id = 0 - this_time = 0.0 + seq_id = 0 + this_time = 0.0 this_duration = 0.0 - prev_time = 0.0 + prev_time = 0.0 prev_duration = 0.0 - this_text = "" + this_text = "" for child in root: if child.tag == "row": - cur_time = float(child.attrib["timestamp"]) + cur_time = float(child.attrib["timestamp"]) cur_duration = float(child.attrib["timedur"]) - cur_text = child.text + cur_text = child.text if this_time == 0.0: this_time = cur_time - delta = cur_time - (prev_time + prev_duration) + delta = cur_time - (prev_time + prev_duration) # rel_tol value is made from trial/error to try and compromise between: # - cutting enough to skip missing words # - not too short, not too long sentences - is_close = math.isclose(cur_time, this_time + this_duration, rel_tol=rel_tol) - is_short = ((this_duration + cur_duration + delta) < MAX_SECS) + is_close = math.isclose( + cur_time, this_time + this_duration, rel_tol=rel_tol + ) + is_short = (this_duration + cur_duration + delta) < MAX_SECS # when the previous element is close enough **and** this does not # go over MAX_SECS, we append content - if (is_close and is_short): + if is_close and is_short: this_duration += cur_duration + delta - this_text += cur_text + this_text += cur_text else: - samples.append((audio_source, target_dir, dataset_basename, this_time, this_duration, this_text, seq_id)) + samples.append( + ( + audio_source, + target_dir, + dataset_basename, + this_time, + this_duration, + this_text, + seq_id, + ) + ) - this_time = cur_time + this_time = cur_time this_duration = cur_duration - this_text = cur_text + this_text = cur_text seq_id += 1 - prev_time = cur_time + prev_time = cur_time prev_duration = cur_duration # Keep track of how many samples are good vs. problematic @@ -425,21 +631,27 @@ def _maybe_import_data(xml_file, audio_source, target_dir, rel_tol=1e-1): assert len(_rows) == imported_samples print_import_report(_counter, SAMPLE_RATE, MAX_SECS) - print("Import efficiency: %.1f%%" % ((_counter["total_time"] / source_frames)*100)) + print( + "Import efficiency: %.1f%%" % ((_counter["total_time"] / source_frames) * 100) + ) print("") return _counter, _rows + def _maybe_convert_wav(mp3_filename, _wav_filename): if not os.path.exists(_wav_filename): print("Converting {} to WAV file: {}".format(mp3_filename, _wav_filename)) transformer = sox.Transformer() - transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS, bitdepth=BIT_DEPTH) + transformer.convert( + samplerate=SAMPLE_RATE, n_channels=CHANNELS, bitdepth=BIT_DEPTH + ) try: transformer.build(mp3_filename, _wav_filename) except sox.core.SoxError: pass + def write_general_csv(target_dir, _rows, _counter): target_csv_template = os.path.join(target_dir, "ccpmf_{}.csv") with open(target_csv_template.format("train"), "w") as train_csv_file: # 80% @@ -461,7 +673,13 @@ def write_general_csv(target_dir, _rows, _counter): writer = dev_writer else: writer = train_writer - writer.writerow({"wav_filename": item[0], "wav_filesize": item[1], "transcript": item[2]}) + writer.writerow( + { + "wav_filename": item[0], + "wav_filesize": item[1], + "transcript": item[2], + } + ) print("") print("~~~~ FINAL STATISTICS ~~~~") @@ -469,11 +687,21 @@ def write_general_csv(target_dir, _rows, _counter): print("~~~~ (FINAL STATISTICS) ~~~~") print("") + if __name__ == "__main__": - PARSER = get_importers_parser(description="Import XML from Conference Centre for Economics, France") + PARSER = get_importers_parser( + description="Import XML from Conference Centre for Economics, France" + ) PARSER.add_argument("target_dir", help="Destination directory") - PARSER.add_argument("--filter_alphabet", help="Exclude samples with characters not in provided alphabet") - PARSER.add_argument("--normalize", action="store_true", help="Converts diacritic characters to their base ones") + PARSER.add_argument( + "--filter_alphabet", + help="Exclude samples with characters not in provided alphabet", + ) + PARSER.add_argument( + "--normalize", + action="store_true", + help="Converts diacritic characters to their base ones", + ) PARAMS = PARSER.parse_args() validate_label = get_validate_label(PARAMS) @@ -481,9 +709,11 @@ if __name__ == "__main__": def label_filter_fun(label): if PARAMS.normalize: - label = unicodedata.normalize("NFKD", label.strip()) \ - .encode("ascii", "ignore") \ + label = ( + unicodedata.normalize("NFKD", label.strip()) + .encode("ascii", "ignore") .decode("ascii", "ignore") + ) label = maybe_normalize(label) label = validate_label(label) if ALPHABET and label: @@ -493,7 +723,9 @@ if __name__ == "__main__": label = None return label - dataset_sources = _download_and_preprocess_data(csv_url=DATASET_RELEASE_CSV, target_dir=PARAMS.target_dir) + dataset_sources = _download_and_preprocess_data( + csv_url=DATASET_RELEASE_CSV, target_dir=PARAMS.target_dir + ) sources_root_dir = os.path.dirname(dataset_sources) all_counter = get_counter() all_rows = [] @@ -504,9 +736,14 @@ if __name__ == "__main__": this_mp3 = os.path.join(sources_root_dir, d[1]) this_rel = float(d[2]) - wav_filename = os.path.join(sources_root_dir, os.path.splitext(os.path.basename(this_mp3))[0] + ".wav") + wav_filename = os.path.join( + sources_root_dir, + os.path.splitext(os.path.basename(this_mp3))[0] + ".wav", + ) _maybe_convert_wav(this_mp3, wav_filename) - counter, rows = _maybe_import_data(this_xml, wav_filename, sources_root_dir, this_rel) + counter, rows = _maybe_import_data( + this_xml, wav_filename, sources_root_dir, this_rel + ) all_counter += counter all_rows += rows diff --git a/bin/import_cv.py b/bin/import_cv.py index a59c9a25..240c6a58 100755 --- a/bin/import_cv.py +++ b/bin/import_cv.py @@ -1,15 +1,14 @@ #!/usr/bin/env python import csv import os -import sys import subprocess +import sys import tarfile from glob import glob from multiprocessing import Pool import progressbar import sox - from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import ( get_counter, diff --git a/bin/import_cv2.py b/bin/import_cv2.py index af702385..48e5f908 100755 --- a/bin/import_cv2.py +++ b/bin/import_cv2.py @@ -14,7 +14,7 @@ from multiprocessing import Pool import progressbar import sox - +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR from coqui_stt_training.util.importers import ( get_counter, @@ -23,7 +23,6 @@ from coqui_stt_training.util.importers import ( get_validate_label, print_import_report, ) -from coqui_stt_ctcdecoder import Alphabet FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] SAMPLE_RATE = 16000 @@ -41,7 +40,11 @@ class LabelFilter: def filter(self, label): if self.normalize: - label = unicodedata.normalize("NFKD", label.strip()).encode("ascii", "ignore").decode("ascii", "ignore") + label = ( + unicodedata.normalize("NFKD", label.strip()) + .encode("ascii", "ignore") + .decode("ascii", "ignore") + ) label = self.validate_fun(label) if self.alphabet and label and not self.alphabet.CanEncode(label): label = None @@ -97,7 +100,15 @@ def one_sample(sample): return (counter, rows) -def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character=None, rows=None, exclude=None): +def _maybe_convert_set( + dataset, + tsv_dir, + audio_dir, + filter_obj, + space_after_every_character=None, + rows=None, + exclude=None, +): exclude_transcripts = set() exclude_speakers = set() if exclude is not None: @@ -116,7 +127,13 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever with open(input_tsv, encoding="utf-8") as input_tsv_file: reader = csv.DictReader(input_tsv_file, delimiter="\t") for row in reader: - samples.append((os.path.join(audio_dir, row["path"]), row["sentence"], row["client_id"])) + samples.append( + ( + os.path.join(audio_dir, row["path"]), + row["sentence"], + row["client_id"], + ) + ) counter = get_counter() num_samples = len(samples) @@ -124,7 +141,9 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever print("Importing mp3 files...") pool = Pool(initializer=init_worker, initargs=(PARAMS,)) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) - for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): + for i, processed in enumerate( + pool.imap_unordered(one_sample, samples), start=1 + ): counter += processed[0] rows += processed[1] bar.update(i) @@ -169,12 +188,20 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False): exclude = [] for dataset in ["test", "dev", "train", "validated", "other"]: - set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character) + set_samples = _maybe_convert_set( + dataset, tsv_dir, audio_dir, space_after_every_character + ) if dataset in ["test", "dev"]: exclude += set_samples if dataset == "validated": - _maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character, - rows=set_samples, exclude=exclude) + _maybe_convert_set( + "train-all", + tsv_dir, + audio_dir, + space_after_every_character, + rows=set_samples, + exclude=exclude, + ) def _maybe_convert_wav(mp3_filename, wav_filename): @@ -212,7 +239,9 @@ def parse_args(): def main(): - audio_dir = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips") + audio_dir = ( + PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips") + ) _preprocess_data(PARAMS.tsv_dir, audio_dir, PARAMS.space_after_every_character) diff --git a/bin/import_fisher.py b/bin/import_fisher.py index 9c6f8a7b..acbee076 100755 --- a/bin/import_fisher.py +++ b/bin/import_fisher.py @@ -10,7 +10,6 @@ import unicodedata import librosa import pandas import soundfile # <= Has an external dependency on libsndfile - from coqui_stt_training.util.importers import validate_label_eng as validate_label # Prerequisite: Having the sph2pipe tool in your PATH: @@ -239,7 +238,7 @@ def _split_and_resample_wav(origAudio, start_time, stop_time, new_wav_file): def _split_sets(filelist): """ randomply split the datasets into train, validation, and test sets where the size of the - validation and test sets are determined by the `get_sample_size` function. + validation and test sets are determined by the `get_sample_size` function. """ random.shuffle(filelist) sample_size = get_sample_size(len(filelist)) @@ -261,8 +260,7 @@ def _split_sets(filelist): def get_sample_size(population_size): - """calculates the sample size for a 99% confidence and 1% margin of error - """ + """calculates the sample size for a 99% confidence and 1% margin of error""" margin_of_error = 0.01 fraction_picking = 0.50 z_score = 2.58 # Corresponds to confidence level 99% diff --git a/bin/import_freestmandarin.py b/bin/import_freestmandarin.py index f1838d91..efd10b7b 100755 --- a/bin/import_freestmandarin.py +++ b/bin/import_freestmandarin.py @@ -5,7 +5,6 @@ import tarfile import numpy as np import pandas - from coqui_stt_training.util.importers import get_importers_parser COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] diff --git a/bin/import_gram_vaani.py b/bin/import_gram_vaani.py index 80bf0241..14b937c6 100755 --- a/bin/import_gram_vaani.py +++ b/bin/import_gram_vaani.py @@ -9,10 +9,9 @@ import urllib from pathlib import Path import pandas as pd -from sox import Transformer - import swifter from coqui_stt_training.util.importers import get_importers_parser, get_validate_label +from sox import Transformer __version__ = "0.1.0" _logger = logging.getLogger(__name__) diff --git a/bin/import_ldc93s1.py b/bin/import_ldc93s1.py index 85088b93..f7ef3dd7 100755 --- a/bin/import_ldc93s1.py +++ b/bin/import_ldc93s1.py @@ -3,7 +3,6 @@ import os import sys import pandas - from coqui_stt_training.util.downloader import maybe_download diff --git a/bin/import_librivox.py b/bin/import_librivox.py index 491488fa..3469f7f9 100755 --- a/bin/import_librivox.py +++ b/bin/import_librivox.py @@ -9,10 +9,10 @@ import unicodedata import pandas import progressbar -from sox import Transformer -from tensorflow.python.platform import gfile - from coqui_stt_training.util.downloader import maybe_download +from sox import Transformer + +from tensorflow.python.platform import gfile SAMPLE_RATE = 16000 diff --git a/bin/import_lingua_libre.py b/bin/import_lingua_libre.py index 743e404d..dc1fa5bd 100755 --- a/bin/import_lingua_libre.py +++ b/bin/import_lingua_libre.py @@ -11,7 +11,7 @@ from multiprocessing import Pool import progressbar import sox - +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import ( get_counter, @@ -20,7 +20,6 @@ from coqui_stt_training.util.importers import ( get_validate_label, print_import_report, ) -from coqui_stt_ctcdecoder import Alphabet FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] SAMPLE_RATE = 16000 @@ -137,9 +136,15 @@ def _maybe_convert_sets(target_dir, extracted_data): pool.close() pool.join() - with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% - with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% - with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% + with open( + target_csv_template.format("train"), "w", encoding="utf-8", newline="" + ) as train_csv_file: # 80% + with open( + target_csv_template.format("dev"), "w", encoding="utf-8", newline="" + ) as dev_csv_file: # 10% + with open( + target_csv_template.format("test"), "w", encoding="utf-8", newline="" + ) as test_csv_file: # 10% train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer.writeheader() dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) @@ -179,7 +184,9 @@ def _maybe_convert_sets(target_dir, extracted_data): def _maybe_convert_wav(ogg_filename, wav_filename): if not os.path.exists(wav_filename): transformer = sox.Transformer() - transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH) + transformer.convert( + samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH + ) try: transformer.build(ogg_filename, wav_filename) except sox.core.SoxError as ex: diff --git a/bin/import_m-ailabs.py b/bin/import_m-ailabs.py index 68417500..9a00a795 100755 --- a/bin/import_m-ailabs.py +++ b/bin/import_m-ailabs.py @@ -9,7 +9,7 @@ from glob import glob from multiprocessing import Pool import progressbar - +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import ( get_counter, @@ -18,7 +18,6 @@ from coqui_stt_training.util.importers import ( get_validate_label, print_import_report, ) -from coqui_stt_ctcdecoder import Alphabet FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] SAMPLE_RATE = 16000 @@ -60,9 +59,20 @@ def one_sample(sample): file_size = -1 frames = 0 if os.path.exists(wav_filename): - tmp_filename = os.path.splitext(wav_filename)[0]+'.tmp.wav' + tmp_filename = os.path.splitext(wav_filename)[0] + ".tmp.wav" subprocess.check_call( - ['sox', wav_filename, '-r', str(SAMPLE_RATE), '-c', '1', '-b', '16', tmp_filename], stderr=subprocess.STDOUT + [ + "sox", + wav_filename, + "-r", + str(SAMPLE_RATE), + "-c", + "1", + "-b", + "16", + tmp_filename, + ], + stderr=subprocess.STDOUT, ) os.rename(tmp_filename, wav_filename) file_size = os.path.getsize(wav_filename) @@ -138,9 +148,15 @@ def _maybe_convert_sets(target_dir, extracted_data): pool.close() pool.join() - with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% - with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% - with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% + with open( + target_csv_template.format("train"), "w", encoding="utf-8", newline="" + ) as train_csv_file: # 80% + with open( + target_csv_template.format("dev"), "w", encoding="utf-8", newline="" + ) as dev_csv_file: # 10% + with open( + target_csv_template.format("test"), "w", encoding="utf-8", newline="" + ) as test_csv_file: # 10% train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer.writeheader() dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) diff --git a/bin/import_magicdata.py b/bin/import_magicdata.py index 8b289804..94e430df 100755 --- a/bin/import_magicdata.py +++ b/bin/import_magicdata.py @@ -5,7 +5,6 @@ import tarfile import wave import pandas - from coqui_stt_training.util.importers import get_importers_parser COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] diff --git a/bin/import_mls_english.py b/bin/import_mls_english.py index 5ff83e3d..4e1e14a3 100644 --- a/bin/import_mls_english.py +++ b/bin/import_mls_english.py @@ -2,10 +2,9 @@ import argparse import ctypes import os +from pathlib import Path import pandas - -from pathlib import Path from tqdm import tqdm diff --git a/bin/import_primewords.py b/bin/import_primewords.py index 4643bd39..8247444c 100755 --- a/bin/import_primewords.py +++ b/bin/import_primewords.py @@ -6,7 +6,6 @@ import tarfile import numpy as np import pandas - from coqui_stt_training.util.importers import get_importers_parser COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] diff --git a/bin/import_slr57.py b/bin/import_slr57.py index ba42881e..3060f5ec 100755 --- a/bin/import_slr57.py +++ b/bin/import_slr57.py @@ -8,7 +8,7 @@ from glob import glob from multiprocessing import Pool import progressbar - +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import ( get_counter, @@ -17,7 +17,6 @@ from coqui_stt_training.util.importers import ( get_validate_label, print_import_report, ) -from coqui_stt_ctcdecoder import Alphabet FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] SAMPLE_RATE = 16000 @@ -157,9 +156,15 @@ def _maybe_convert_sets(target_dir, extracted_data): pool.close() pool.join() - with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% - with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% - with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% + with open( + target_csv_template.format("train"), "w", encoding="utf-8", newline="" + ) as train_csv_file: # 80% + with open( + target_csv_template.format("dev"), "w", encoding="utf-8", newline="" + ) as dev_csv_file: # 10% + with open( + target_csv_template.format("test"), "w", encoding="utf-8", newline="" + ) as test_csv_file: # 10% train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer.writeheader() dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) diff --git a/bin/import_swb.py b/bin/import_swb.py index b192d9f8..068a2f73 100755 --- a/bin/import_swb.py +++ b/bin/import_swb.py @@ -16,7 +16,6 @@ import librosa import pandas import requests import soundfile # <= Has an external dependency on libsndfile - from coqui_stt_training.util.importers import validate_label_eng as validate_label # ARCHIVE_NAME refers to ISIP alignments from 01/29/03 @@ -293,7 +292,7 @@ def _split_wav(origAudio, start_time, stop_time, new_wav_file): def _split_sets(filelist): """ randomply split the datasets into train, validation, and test sets where the size of the - validation and test sets are determined by the `get_sample_size` function. + validation and test sets are determined by the `get_sample_size` function. """ random.shuffle(filelist) sample_size = get_sample_size(len(filelist)) @@ -315,8 +314,7 @@ def _split_sets(filelist): def get_sample_size(population_size): - """calculates the sample size for a 99% confidence and 1% margin of error - """ + """calculates the sample size for a 99% confidence and 1% margin of error""" margin_of_error = 0.01 fraction_picking = 0.50 z_score = 2.58 # Corresponds to confidence level 99% diff --git a/bin/import_swc.py b/bin/import_swc.py index 59bc3084..4b984131 100755 --- a/bin/import_swc.py +++ b/bin/import_swc.py @@ -21,10 +21,9 @@ from multiprocessing.pool import ThreadPool import progressbar import sox - +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import validate_label_eng as validate_label -from coqui_stt_ctcdecoder import Alphabet SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" SWC_ARCHIVE = "SWC_{language}.tar" @@ -173,7 +172,6 @@ def in_alphabet(alphabet, c): return alphabet.CanEncode(c) if alphabet else True - ALPHABETS = {} @@ -202,8 +200,16 @@ def label_filter(label, language): dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else "" alphabet = get_alphabet(language) for c in label: - if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c): - c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") + if ( + CLI_ARGS.normalize + and c not in dont_normalize + and not in_alphabet(alphabet, c) + ): + c = ( + unicodedata.normalize("NFKD", c) + .encode("ascii", "ignore") + .decode("ascii", "ignore") + ) for sc in c: if not in_alphabet(alphabet, sc): return None, "illegal character" diff --git a/bin/import_ted.py b/bin/import_ted.py index f88a248f..cb63a4f9 100755 --- a/bin/import_ted.py +++ b/bin/import_ted.py @@ -7,11 +7,11 @@ from glob import glob from os import makedirs, path, remove, rmdir import pandas -from sox import Transformer -from tensorflow.python.platform import gfile - from coqui_stt_training.util.downloader import maybe_download from coqui_stt_training.util.stm import parse_stm_file +from sox import Transformer + +from tensorflow.python.platform import gfile def _download_and_preprocess_data(data_dir): diff --git a/bin/import_ts.py b/bin/import_ts.py index 0ce3fdf2..270133de 100755 --- a/bin/import_ts.py +++ b/bin/import_ts.py @@ -8,7 +8,6 @@ from multiprocessing import Pool import progressbar import sox - import unidecode from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import ( @@ -132,9 +131,15 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False): pool.close() pool.join() - with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% - with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% - with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% + with open( + target_csv_template.format("train"), "w", encoding="utf-8", newline="" + ) as train_csv_file: # 80% + with open( + target_csv_template.format("dev"), "w", encoding="utf-8", newline="" + ) as dev_csv_file: # 10% + with open( + target_csv_template.format("test"), "w", encoding="utf-8", newline="" + ) as test_csv_file: # 10% train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer.writeheader() dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) diff --git a/bin/import_tuda.py b/bin/import_tuda.py index 99ffd9f5..a8e2bf49 100755 --- a/bin/import_tuda.py +++ b/bin/import_tuda.py @@ -13,10 +13,9 @@ import xml.etree.ElementTree as ET from collections import Counter import progressbar - +from coqui_stt_ctcdecoder import Alphabet from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import validate_label_eng as validate_label -from coqui_stt_ctcdecoder import Alphabet TUDA_VERSION = "v2" TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION) @@ -55,7 +54,11 @@ def check_and_prepare_sentence(sentence): chars = [] for c in sentence: if CLI_ARGS.normalize and c not in "äöüß" and not in_alphabet(c): - c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") + c = ( + unicodedata.normalize("NFKD", c) + .encode("ascii", "ignore") + .decode("ascii", "ignore") + ) for sc in c: if not in_alphabet(c): return None @@ -118,7 +121,7 @@ def write_csvs(extracted): sentence = list(meta.iter("cleaned_sentence"))[0].text sentence = check_and_prepare_sentence(sentence) if sentence is None: - reasons['alphabet filter'] += 1 + reasons["alphabet filter"] += 1 continue for wav_name in wav_names: sample_counter += 1 diff --git a/bin/import_vctk.py b/bin/import_vctk.py index b2b85b6d..3b4c531d 100755 --- a/bin/import_vctk.py +++ b/bin/import_vctk.py @@ -10,7 +10,6 @@ from zipfile import ZipFile import librosa import progressbar - from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.importers import ( get_counter, diff --git a/bin/import_voxforge.py b/bin/import_voxforge.py index b01dca72..c0e293f8 100755 --- a/bin/import_voxforge.py +++ b/bin/import_voxforge.py @@ -13,9 +13,10 @@ from os import makedirs, path import pandas from bs4 import BeautifulSoup -from tensorflow.python.platform import gfile from coqui_stt_training.util.downloader import maybe_download +from tensorflow.python.platform import gfile + """The number of jobs to run in parallel""" NUM_PARALLEL = 8 diff --git a/bin/play.py b/bin/play.py index 59433e18..c80e2a02 100755 --- a/bin/play.py +++ b/bin/play.py @@ -4,14 +4,26 @@ Tool for playing (and augmenting) single samples or samples from Sample Database Use "python3 play.py -h" for help """ -import os -import sys -import random import argparse +import os +import random +import sys -from coqui_stt_training.util.audio import get_loadable_audio_type_from_extension, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV -from coqui_stt_training.util.sample_collections import SampleList, LabeledSample, samples_from_source -from coqui_stt_training.util.augmentations import parse_augmentations, apply_sample_augmentations, SampleAugmentation +from coqui_stt_training.util.audio import ( + AUDIO_TYPE_PCM, + AUDIO_TYPE_WAV, + get_loadable_audio_type_from_extension, +) +from coqui_stt_training.util.augmentations import ( + SampleAugmentation, + apply_sample_augmentations, + parse_augmentations, +) +from coqui_stt_training.util.sample_collections import ( + LabeledSample, + SampleList, + samples_from_source, +) def get_samples_in_play_order(): @@ -43,11 +55,13 @@ def play_collection(): if any(not isinstance(a, SampleAugmentation) for a in augmentations): print("Warning: Some of the augmentations cannot be simulated by this command.") samples = get_samples_in_play_order() - samples = apply_sample_augmentations(samples, - audio_type=AUDIO_TYPE_PCM, - augmentations=augmentations, - process_ahead=0, - clock=CLI_ARGS.clock) + samples = apply_sample_augmentations( + samples, + audio_type=AUDIO_TYPE_PCM, + augmentations=augmentations, + process_ahead=0, + clock=CLI_ARGS.clock, + ) for sample in samples: if not CLI_ARGS.quiet: print('Sample "{}"'.format(sample.sample_id), file=sys.stderr) @@ -57,10 +71,12 @@ def play_collection(): sample.change_audio_type(AUDIO_TYPE_WAV) sys.stdout.buffer.write(sample.audio.getvalue()) return - wave_obj = simpleaudio.WaveObject(sample.audio, - sample.audio_format.channels, - sample.audio_format.width, - sample.audio_format.rate) + wave_obj = simpleaudio.WaveObject( + sample.audio, + sample.audio_format.channels, + sample.audio_format.width, + sample.audio_format.rate, + ) play_obj = wave_obj.play() play_obj.wait_done() @@ -70,7 +86,9 @@ def handle_args(): description="Tool for playing (and augmenting) single samples or samples from Sample Databases (SDB files) " "and Coqui STT CSV files" ) - parser.add_argument("source", help="Sample DB, CSV or WAV file to play samples from") + parser.add_argument( + "source", help="Sample DB, CSV or WAV file to play samples from" + ) parser.add_argument( "--start", type=int, @@ -90,7 +108,7 @@ def handle_args(): ) parser.add_argument( "--augment", - action='append', + action="append", help="Add an augmentation operation", ) parser.add_argument( @@ -98,8 +116,8 @@ def handle_args(): type=float, default=0.5, help="Simulates clock value used for augmentations during training." - "Ranges from 0.0 (representing parameter start values) to" - "1.0 (representing parameter end values)", + "Ranges from 0.0 (representing parameter start values) to" + "1.0 (representing parameter end values)", ) parser.add_argument( "--pipe", @@ -120,7 +138,9 @@ if __name__ == "__main__": try: import simpleaudio except ModuleNotFoundError: - print('Unless using the --pipe flag, play.py requires Python package "simpleaudio" for playing samples') + print( + 'Unless using the --pipe flag, play.py requires Python package "simpleaudio" for playing samples' + ) sys.exit(1) try: play_collection() diff --git a/data/README.rst b/data/README.rst index 289146c9..6868c3b4 100644 --- a/data/README.rst +++ b/data/README.rst @@ -8,4 +8,3 @@ This directory contains language-specific data files. Most importantly, you will 2. A script used to generate a binary n-gram language model: ``data/lm/generate_lm.py``. For more information on how to build these resources from scratch, see the ``External scorer scripts`` section on `stt.readthedocs.io `_. - diff --git a/data/lm/generate_lm.py b/data/lm/generate_lm.py index 47941437..f26f257a 100644 --- a/data/lm/generate_lm.py +++ b/data/lm/generate_lm.py @@ -78,20 +78,20 @@ def build_lm(args, data_lower, vocab_str): print("\nCreating ARPA file ...") lm_path = os.path.join(args.output_dir, "lm.arpa") subargs = [ - os.path.join(args.kenlm_bins, "lmplz"), - "--order", - str(args.arpa_order), - "--temp_prefix", - args.output_dir, - "--memory", - args.max_arpa_memory, - "--text", - data_lower, - "--arpa", - lm_path, - "--prune", - *args.arpa_prune.split("|"), - ] + os.path.join(args.kenlm_bins, "lmplz"), + "--order", + str(args.arpa_order), + "--temp_prefix", + args.output_dir, + "--memory", + args.max_arpa_memory, + "--text", + data_lower, + "--arpa", + lm_path, + "--prune", + *args.arpa_prune.split("|"), + ] if args.discount_fallback: subargs += ["--discount_fallback"] subprocess.check_call(subargs) diff --git a/data/smoke_test/russian_sample_data/alphabet.ru b/data/smoke_test/russian_sample_data/alphabet.ru index 6dc0e3cc..262d5706 100644 --- a/data/smoke_test/russian_sample_data/alphabet.ru +++ b/data/smoke_test/russian_sample_data/alphabet.ru @@ -1,4 +1,4 @@ - + о е а diff --git a/data/smoke_test/russian_sample_data/ru.csv b/data/smoke_test/russian_sample_data/ru.csv index 18aeb63b..75766110 100644 --- a/data/smoke_test/russian_sample_data/ru.csv +++ b/data/smoke_test/russian_sample_data/ru.csv @@ -1,2 +1,2 @@ wav_filename,wav_filesize,transcript -ru.wav,0,бедняга ребят на его месте должен был быть я \ No newline at end of file +ru.wav,0,бедняга ребят на его месте должен был быть я diff --git a/data/smoke_test/vocab.pruned.bytes.txt b/data/smoke_test/vocab.pruned.bytes.txt index 4c4c80cb..4a77dda4 100644 --- a/data/smoke_test/vocab.pruned.bytes.txt +++ b/data/smoke_test/vocab.pruned.bytes.txt @@ -3537,4 +3537,4 @@ p r o t e c t e d t h a t ' s f o r m e r m e a n t -j o i n t \ No newline at end of file +j o i n t diff --git a/doc/AUGMENTATION.rst b/doc/AUGMENTATION.rst index 696d0a83..0b168f7d 100644 --- a/doc/AUGMENTATION.rst +++ b/doc/AUGMENTATION.rst @@ -5,7 +5,7 @@ Training Data Augmentation This document is an overview of the augmentation techniques available for training with STT. -Training data augmentations can help STT models better transcribe new speech at deployment time. The basic intuition behind data augmentation is the following: by distorting, modifying, or adding to your existing audio data, you can create a training set many times larger than what you started with. If you use a larger training data set to train as STT model, you force the model to learn more generalizable characteristics of speech, making `overfitting `_ more difficult. If you can't find a larger data set of speech, you can create one with data augmentation. +Training data augmentations can help STT models better transcribe new speech at deployment time. The basic intuition behind data augmentation is the following: by distorting, modifying, or adding to your existing audio data, you can create a training set many times larger than what you started with. If you use a larger training data set to train as STT model, you force the model to learn more generalizable characteristics of speech, making `overfitting `_ more difficult. If you can't find a larger data set of speech, you can create one with data augmentation. We have implemented a pre-processing pipeline with various augmentation techniques on audio data (i.e. raw ``PCM`` and spectrograms). diff --git a/doc/BUILDING_DotNet.rst b/doc/BUILDING_DotNet.rst index 6c0b67cf..a3755bfe 100644 --- a/doc/BUILDING_DotNet.rst +++ b/doc/BUILDING_DotNet.rst @@ -119,7 +119,7 @@ Building the native_client There's one last command to run before building, you need to run the `configure.py `_ inside ``tensorflow`` cloned directory. -At this point we are ready to start building the ``native_client``, go to ``tensorflow`` sub-directory, following our examples should be ``D:\cloned\STT\tensorflow``. +At this point we are ready to start building the ``native_client``, go to ``tensorflow`` sub-directory, following our examples should be ``D:\cloned\STT\tensorflow``. CPU ~~~ diff --git a/doc/CHECKPOINTING.rst b/doc/CHECKPOINTING.rst index 5c5665ec..ae7461bf 100644 --- a/doc/CHECKPOINTING.rst +++ b/doc/CHECKPOINTING.rst @@ -3,7 +3,7 @@ Checkpointing ============= -Checkpoints are representations of the parameters of a neural network. During training, model parameters are continually updated, and checkpoints allow graceful interruption of a training run without data loss. If you interrupt a training run for any reason, you can pick up where you left off by using the checkpoints as a starting place. This is the exact same logic behind :ref:`model fine-tuning `. +Checkpoints are representations of the parameters of a neural network. During training, model parameters are continually updated, and checkpoints allow graceful interruption of a training run without data loss. If you interrupt a training run for any reason, you can pick up where you left off by using the checkpoints as a starting place. This is the exact same logic behind :ref:`model fine-tuning `. Checkpointing occurs at a configurable time interval. Resuming from checkpoints happens automatically by re-starting training with the same ``--checkpoint_dir`` of the former run. Alternatively, you can specify more fine grained options with ``--load_checkpoint_dir`` and ``--save_checkpoint_dir``, which specify separate locations to use for loading and saving checkpoints respectively. diff --git a/doc/DEPLOYMENT.rst b/doc/DEPLOYMENT.rst index d4ad5c00..ab0f0176 100644 --- a/doc/DEPLOYMENT.rst +++ b/doc/DEPLOYMENT.rst @@ -134,7 +134,7 @@ The script ``taskcluster.py`` will download ``native_client.tar.xz`` (which incl Alternatively you may manually download the ``native_client.tar.xz`` from the `releases page `_. -Assuming you have :ref:`downloaded the pre-trained models `, you can use the client as such: +Assuming you have :ref:`downloaded the pre-trained models `, you can use the client as such: .. code-block:: bash diff --git a/doc/HotWordBoosting-Examples.rst b/doc/HotWordBoosting-Examples.rst index ec3ce4db..17cc53d9 100644 --- a/doc/HotWordBoosting-Examples.rst +++ b/doc/HotWordBoosting-Examples.rst @@ -1,12 +1,12 @@ Hot-word boosting API Usage example =================================== -With the 🐸STT 0.9 release a new API feature was introduced that allows boosting probability from the scorer of given words. It is exposed in all bindings (C, Python, JS, Java and .Net). +With the 🐸STT 0.9 release a new API feature was introduced that allows boosting probability from the scorer of given words. It is exposed in all bindings (C, Python, JS, Java and .Net). Currently, it provides three methods for the Model class: - ``AddHotWord(word, boost)`` -- ``EraseHotWord(word)`` +- ``EraseHotWord(word)`` - ``ClearHotWords()`` Exact API binding for the language you are using can be found in API Reference. @@ -14,7 +14,7 @@ Exact API binding for the language you are using can be found in API Reference. General usage ------------- -It is worth noting that boosting non-existent words in scorer (mostly proper nouns) or a word that share no phonetic prefix with other word in the input audio don't change the final transcription. Additionally, hot-word that has a space will not be taken into consideration, meaning that combination of words can not be boosted and each word must be added as hot-word separately. +It is worth noting that boosting non-existent words in scorer (mostly proper nouns) or a word that share no phonetic prefix with other word in the input audio don't change the final transcription. Additionally, hot-word that has a space will not be taken into consideration, meaning that combination of words can not be boosted and each word must be added as hot-word separately. Adjusting the boosting value ---------------------------- @@ -29,9 +29,9 @@ There is a user contributed script available on ``STT-examples`` repository for Positive value boosting ----------------------- -By adding a positive boost value to one of the words it is possible to increase the probability of the word occurence. This is particularly useful for detecting speech that is expected by the system. +By adding a positive boost value to one of the words it is possible to increase the probability of the word occurence. This is particularly useful for detecting speech that is expected by the system. -In the output, overextensive positive boost value (e.g. 250.0 but it does vary) may cause a word following the boosted hot-word to be split into separate letters. This problem is related to the scorer structure and currently only way to avoid it is to tune boost to a lower value. +In the output, overextensive positive boost value (e.g. 250.0 but it does vary) may cause a word following the boosted hot-word to be split into separate letters. This problem is related to the scorer structure and currently only way to avoid it is to tune boost to a lower value. Negative value boosting ----------------------- @@ -40,7 +40,7 @@ Respectively, applying negative boost value might cause the selected word to occ Previously mentioned problem where extensive boost value caused letter splitting doesn't arise for negative boost values. -Example +Example ------- To use hot-word boosting just add hot-words of your choice performing a speech-to-text operation with a ``Model``. You can also erase boosting of a chosen word or clear it for all hot-words. @@ -52,5 +52,5 @@ To use hot-word boosting just add hot-words of your choice performing a speech-t ds.addHotWord(word, boosting) ... print(ds.stt(audio)) - + Adding boost value to a word repeatedly or erasing hot-word without previously boosting it results in an error. diff --git a/doc/TRAINING_INTRO.rst b/doc/TRAINING_INTRO.rst index cdc525fe..f0312dfe 100644 --- a/doc/TRAINING_INTRO.rst +++ b/doc/TRAINING_INTRO.rst @@ -138,7 +138,7 @@ Data Format Audio data is expected to be stored as WAV, sampled at 16kHz, and mono-channel. There's no hard expectations for the length of individual audio files, but in our experience, training is most successful when WAV files range from 5 to 20 seconds in length. Your training data should match as closely as possible the kind of speech you expect at deployment. You can read more about the significant characteristics of speech with regard to STT :ref:`here `. -Text transcripts should be formatted exactly as the transcripts you expect your model to produce at deployment. If you want your model to produce capital letters, your transcripts should include capital letters. If you want your model to produce punctuation, your transcripts should include punctuation. Keep in mind that the more characters you include in your transcripts, the more difficult the task becomes for your model. STT models learn from experience, and if there's very few examples in the training data, the model will have a hard time learning rare characters (e.g. the "ï" in "naïve"). +Text transcripts should be formatted exactly as the transcripts you expect your model to produce at deployment. If you want your model to produce capital letters, your transcripts should include capital letters. If you want your model to produce punctuation, your transcripts should include punctuation. Keep in mind that the more characters you include in your transcripts, the more difficult the task becomes for your model. STT models learn from experience, and if there's very few examples in the training data, the model will have a hard time learning rare characters (e.g. the "ï" in "naïve"). CSV file format """"""""""""""" diff --git a/doc/conf.py b/doc/conf.py index 0feec64e..594c81ea 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -22,21 +22,27 @@ import os import sys -sys.path.insert(0, os.path.abspath('../')) +sys.path.insert(0, os.path.abspath("../")) -autodoc_mock_imports = ['stt'] +autodoc_mock_imports = ["stt"] # This is in fact only relevant on ReadTheDocs, but we want to run the same way # on our CI as in RTD to avoid regressions on RTD that we would not catch on CI import subprocess + parent = subprocess.check_output("cd ../ && pwd", shell=True).decode().strip() -os.environ["PATH"] = os.path.join(parent, 'node_modules', '.bin') + ':' + os.environ["PATH"] -subprocess.check_call('cd ../ && npm install typedoc@0.17.4 typescript@3.8.3 @types/node@13.9.x', shell=True) -subprocess.check_call('env', shell=True) -subprocess.check_call('which typedoc', shell=True) -subprocess.check_call('cd ../ && doxygen doc/doxygen-c.conf', shell=True) -subprocess.check_call('cd ../ && doxygen doc/doxygen-java.conf', shell=True) -subprocess.check_call('cd ../ && doxygen doc/doxygen-dotnet.conf', shell=True) +os.environ["PATH"] = ( + os.path.join(parent, "node_modules", ".bin") + ":" + os.environ["PATH"] +) +subprocess.check_call( + "cd ../ && npm install typedoc@0.17.4 typescript@3.8.3 @types/node@13.9.x", + shell=True, +) +subprocess.check_call("env", shell=True) +subprocess.check_call("which typedoc", shell=True) +subprocess.check_call("cd ../ && doxygen doc/doxygen-c.conf", shell=True) +subprocess.check_call("cd ../ && doxygen doc/doxygen-java.conf", shell=True) +subprocess.check_call("cd ../ && doxygen doc/doxygen-dotnet.conf", shell=True) # -- General configuration ------------------------------------------------ @@ -44,11 +50,11 @@ import semver # -- Project information ----------------------------------------------------- -project = u'Coqui STT' -copyright = '2021 Coqui GmbH, 2020 DeepSpeech authors, 2019-2020 Mozilla Corporation' -author = 'Coqui GmbH' +project = u"Coqui STT" +copyright = "2021 Coqui GmbH, 2020 DeepSpeech authors, 2019-2020 Mozilla Corporation" +author = "Coqui GmbH" -with open('../VERSION', 'r') as ver: +with open("../VERSION", "r") as ver: v = ver.read().strip() vv = semver.parse(v) @@ -56,7 +62,7 @@ vv = semver.parse(v) # |version| and |release|, also used in various other places throughout the # built documents. # The short X.Y version -version = '{}.{}'.format(vv['major'], vv['minor']) +version = "{}.{}".format(vv["major"], vv["minor"]) # The full version, including alpha/beta/rc tags release = v @@ -68,22 +74,22 @@ release = v # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.extlinks', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx_js', - 'sphinx_csharp', - 'breathe', - 'recommonmark', + "sphinx.ext.autodoc", + "sphinx.ext.extlinks", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx_js", + "sphinx_csharp", + "breathe", + "recommonmark", ] breathe_projects = { - "stt-c": "xml-c/", - "stt-java": "xml-java/", - "stt-dotnet": "xml-dotnet/", + "stt-c": "xml-c/", + "stt-java": "xml-java/", + "stt-dotnet": "xml-dotnet/", } js_source_path = "../native_client/javascript/index.ts" @@ -91,16 +97,16 @@ js_language = "typescript" jsdoc_config_path = "../native_client/javascript/tsconfig.json" # Add any paths that contain templates here, relative to this directory. -templates_path = ['.templates'] +templates_path = [".templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The main toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -112,10 +118,10 @@ language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store', 'node_modules', 'examples'] +exclude_patterns = [".build", "Thumbs.db", ".DS_Store", "node_modules", "examples"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -128,18 +134,18 @@ add_module_names = False # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'furo' +html_theme = "furo" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['.static'] +html_static_path = [".static"] # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'STTdoc' +htmlhelp_basename = "STTdoc" # -- Options for LaTeX output --------------------------------------------- @@ -148,15 +154,12 @@ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -166,8 +169,7 @@ latex_elements = { # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'STT.tex', u'Coqui STT Documentation', - u'Coqui GmbH', 'manual'), + (master_doc, "STT.tex", u"Coqui STT Documentation", u"Coqui GmbH", "manual"), ] @@ -175,10 +177,7 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'stt', u'Coqui STT Documentation', - [author], 1) -] +man_pages = [(master_doc, "stt", u"Coqui STT Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -187,16 +186,21 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'STT', u'Coqui STT Documentation', - author, 'STT', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "STT", + u"Coqui STT Documentation", + author, + "STT", + "One line description of project.", + "Miscellaneous", + ), ] - - # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} -extlinks = {'github': ('https://github.com/coqui-ai/STT/blob/v{}/%s'.format(release), - '%s')} +extlinks = { + "github": ("https://github.com/coqui-ai/STT/blob/v{}/%s".format(release), "%s") +} diff --git a/doc/index.rst b/doc/index.rst index 6a4f4011..204806f4 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -24,7 +24,7 @@ Coqui STT Quickstart: Deployment ^^^^^^^^^^^^^^^^^^^^^^ -The fastest way to deploy a pre-trained 🐸STT model is with `pip` with Python 3.5 or higher (*Note - only Linux supported at this time. We are working to get our normally supported packages back up and running.*): +The fastest way to deploy a pre-trained 🐸STT model is with `pip` with Python 3.5 or higher (*Note - only Linux supported at this time. We are working to get our normally supported packages back up and running.*): .. code-block:: bash diff --git a/doc/playbook/ENVIRONMENT.md b/doc/playbook/ENVIRONMENT.md index 72594126..9d1526a3 100644 --- a/doc/playbook/ENVIRONMENT.md +++ b/doc/playbook/ENVIRONMENT.md @@ -16,7 +16,7 @@ + [Testing the image by creating a container and running a script](#testing-the-image-by-creating-a-container-and-running-a-script) * [Setting up a bind mount to store persistent data](#setting-up-a-bind-mount-to-store-persistent-data) * [Extending the base `stt-train` Docker image for your needs](#extending-the-base--stt-train--docker-image-for-your-needs) - + This section of the Playbook assumes you are comfortable installing 🐸STT and using it with a pre-trained model, and that you are comfortable setting up a Python _virtual environment_. Here, we provide information on setting up a Docker environment for training your own speech recognition model using 🐸STT. We also cover dependencies Docker has for NVIDIA GPUs, so that you can use your GPU(s) for training a model. @@ -48,7 +48,7 @@ By default, your machine should already have GPU drivers installed. A good way t ``` $ nvidia-smi -Sat Jan 9 11:48:50 2021 +Sat Jan 9 11:48:50 2021 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 450.80.02 Driver Version: 450.80.02 CUDA Version: 11.0 | |-------------------------------+----------------------+----------------------+ @@ -195,7 +195,7 @@ This command assumes that `/bin/bash` will be invoked as the `root` user. This i When you run the above command, you should see the following prompt: ``` -________ _______________ +________ _______________ ___ __/__________________________________ ____/__ /________ __ __ / _ _ \_ __ \_ ___/ __ \_ ___/_ /_ __ /_ __ \_ | /| / / _ / / __/ / / /(__ )/ /_/ / / _ __/ _ / / /_/ /_ |/ |/ / diff --git a/doc/playbook/README.md b/doc/playbook/README.md index 017fd9ff..192a3c6a 100644 --- a/doc/playbook/README.md +++ b/doc/playbook/README.md @@ -28,7 +28,7 @@ If you are training a model that uses a different alphabet to English, for examp ## [Building your own scorer](SCORER.md) -Learn what the scorer does, and how you can go about building your own. +Learn what the scorer does, and how you can go about building your own. ## [Acoustic model and language model](AM_vs_LM.md) @@ -66,7 +66,7 @@ Here, we've linked to several resources that you may find helpful; they're liste * [Google's machine learning crash course](https://developers.google.com/machine-learning/crash-course/ml-intro) provides a gentle introduction to the main concepts of machine learning, including _gradient descent_, _learning rate_, _training, test and validation sets_ and _overfitting_. -* If machine learning is something that sparks your interest, then you may enjoy [the MIT Open Learning Library's Introduction to Machine Learning course](https://openlearninglibrary.mit.edu/courses/course-v1:MITx+6.036+1T2019/course/), a 13-week college-level course covering perceptrons, neural networks, support vector machines and convolutional neural networks. +* If machine learning is something that sparks your interest, then you may enjoy [the MIT Open Learning Library's Introduction to Machine Learning course](https://openlearninglibrary.mit.edu/courses/course-v1:MITx+6.036+1T2019/course/), a 13-week college-level course covering perceptrons, neural networks, support vector machines and convolutional neural networks. --- diff --git a/doc/playbook/TESTING.md b/doc/playbook/TESTING.md index 74730a13..a56e0367 100644 --- a/doc/playbook/TESTING.md +++ b/doc/playbook/TESTING.md @@ -23,7 +23,7 @@ When you invoked `train.py` in the [training](TRAINING.md) section, and trained ``` Testing model on stt-data/cv-corpus-6.1-2020-12-11/id/clips/test.csv -Test epoch | Steps: 1844 | Elapsed Time: 0:51:11 +Test epoch | Steps: 1844 | Elapsed Time: 0:51:11 Test on stt-data/cv-corpus-6.1-2020-12-11/id/clips/test.csv - WER: 1.000000, CER: 0.824103, loss: 104.989326 -------------------------------------------------------------------------------- Best WER: @@ -156,7 +156,7 @@ _Fine tuning_ and _transfer learning_ are two processes used to improve the accu For more information on [fine tuning in 🐸STT, please consult the documentation](https://stt.readthedocs.io/en/latest/TRAINING.html#fine-tuning-same-alphabet). -For more information on [transfer learning in 🐸STT, please consult the documentation](https://stt.readthedocs.io/en/latest/TRAINING.html#transfer-learning-new-alphabet). +For more information on [transfer learning in 🐸STT, please consult the documentation](https://stt.readthedocs.io/en/latest/TRAINING.html#transfer-learning-new-alphabet). --- diff --git a/evaluate.py b/evaluate.py index eca856b2..86fb06d0 100644 --- a/evaluate.py +++ b/evaluate.py @@ -2,11 +2,11 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function -if __name__ == '__main__': +if __name__ == "__main__": try: from coqui_stt_training import evaluate as ds_evaluate except ImportError: - print('Training package is not installed. See training documentation.') + print("Training package is not installed. See training documentation.") raise ds_evaluate.run_script() diff --git a/evaluate_tflite.py b/evaluate_tflite.py index d8cff40f..d392a7ba 100644 --- a/evaluate_tflite.py +++ b/evaluate_tflite.py @@ -2,22 +2,22 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function -import absl.app import argparse -import numpy as np -import wave import csv import os import sys +import wave +from functools import partial +from multiprocessing import JoinableQueue, Manager, Process, cpu_count -from stt import Model +import absl.app +import numpy as np from coqui_stt_training.util.evaluate_tools import calculate_and_print_report from coqui_stt_training.util.flags import create_flags -from functools import partial -from multiprocessing import JoinableQueue, Process, cpu_count, Manager -from six.moves import zip, range +from six.moves import range, zip +from stt import Model -r''' +r""" This module should be self-contained: - build libstt.so with TFLite: - bazel build [...] --define=runtime=tflite [...] //native_client:libstt.so @@ -27,10 +27,11 @@ This module should be self-contained: - pip install -r requirements_eval_tflite.txt Then run with a TFLite model, a scorer and a CSV test file -''' +""" + def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask): - os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask) + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_mask) ds = Model(model) ds.enableExternalScorer(scorer) @@ -38,29 +39,41 @@ def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask): try: msg = queue_in.get() - filename = msg['filename'] - fin = wave.open(filename, 'rb') + filename = msg["filename"] + fin = wave.open(filename, "rb") audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) fin.close() decoded = ds.stt(audio) - queue_out.put({'wav': filename, 'prediction': decoded, 'ground_truth': msg['transcript']}) + queue_out.put( + { + "wav": filename, + "prediction": decoded, + "ground_truth": msg["transcript"], + } + ) except FileNotFoundError as ex: - print('FileNotFoundError: ', ex) + print("FileNotFoundError: ", ex) - print(queue_out.qsize(), end='\r') # Update the current progress + print(queue_out.qsize(), end="\r") # Update the current progress queue_in.task_done() + def main(args, _): manager = Manager() - work_todo = JoinableQueue() # this is where we are going to store input data + work_todo = JoinableQueue() # this is where we are going to store input data work_done = manager.Queue() # this where we are gonna push them out processes = [] for i in range(args.proc): - worker_process = Process(target=tflite_worker, args=(args.model, args.scorer, work_todo, work_done, i), daemon=True, name='tflite_process_{}'.format(i)) - worker_process.start() # Launch reader() as a separate python process + worker_process = Process( + target=tflite_worker, + args=(args.model, args.scorer, work_todo, work_done, i), + daemon=True, + name="tflite_process_{}".format(i), + ) + worker_process.start() # Launch reader() as a separate python process processes.append(worker_process) print([x.name for x in processes]) @@ -71,56 +84,75 @@ def main(args, _): losses = [] wav_filenames = [] - with open(args.csv, 'r') as csvfile: + with open(args.csv, "r") as csvfile: csvreader = csv.DictReader(csvfile) count = 0 for row in csvreader: count += 1 # Relative paths are relative to the folder the CSV file is in - if not os.path.isabs(row['wav_filename']): - row['wav_filename'] = os.path.join(os.path.dirname(args.csv), row['wav_filename']) - work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']}) - wav_filenames.extend(row['wav_filename']) + if not os.path.isabs(row["wav_filename"]): + row["wav_filename"] = os.path.join( + os.path.dirname(args.csv), row["wav_filename"] + ) + work_todo.put( + {"filename": row["wav_filename"], "transcript": row["transcript"]} + ) + wav_filenames.extend(row["wav_filename"]) - print('Totally %d wav entries found in csv\n' % count) + print("Totally %d wav entries found in csv\n" % count) work_todo.join() - print('\nTotally %d wav file transcripted' % work_done.qsize()) + print("\nTotally %d wav file transcripted" % work_done.qsize()) while not work_done.empty(): msg = work_done.get() losses.append(0.0) - ground_truths.append(msg['ground_truth']) - predictions.append(msg['prediction']) - wavlist.append(msg['wav']) + ground_truths.append(msg["ground_truth"]) + predictions.append(msg["prediction"]) + wavlist.append(msg["wav"]) # Print test summary - _ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv) + _ = calculate_and_print_report( + wav_filenames, ground_truths, predictions, losses, args.csv + ) if args.dump: - with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout: + with open(args.dump + ".txt", "w") as ftxt, open( + args.dump + ".out", "w" + ) as fout: for wav, txt, out in zip(wavlist, ground_truths, predictions): - ftxt.write('%s %s\n' % (wav, txt)) - fout.write('%s %s\n' % (wav, out)) - print('Reference texts dumped to %s.txt' % args.dump) - print('Transcription dumped to %s.out' % args.dump) + ftxt.write("%s %s\n" % (wav, txt)) + fout.write("%s %s\n" % (wav, out)) + print("Reference texts dumped to %s.txt" % args.dump) + print("Transcription dumped to %s.out" % args.dump) + def parse_args(): - parser = argparse.ArgumentParser(description='Computing TFLite accuracy') - parser.add_argument('--model', required=True, - help='Path to the model (protocol buffer binary file)') - parser.add_argument('--scorer', required=True, - help='Path to the external scorer file') - parser.add_argument('--csv', required=True, - help='Path to the CSV source file') - parser.add_argument('--proc', required=False, default=cpu_count(), type=int, - help='Number of processes to spawn, defaulting to number of CPUs') - parser.add_argument('--dump', required=False, - help='Path to dump the results as text file, with one line for each wav: "wav transcription".') + parser = argparse.ArgumentParser(description="Computing TFLite accuracy") + parser.add_argument( + "--model", required=True, help="Path to the model (protocol buffer binary file)" + ) + parser.add_argument( + "--scorer", required=True, help="Path to the external scorer file" + ) + parser.add_argument("--csv", required=True, help="Path to the CSV source file") + parser.add_argument( + "--proc", + required=False, + default=cpu_count(), + type=int, + help="Number of processes to spawn, defaulting to number of CPUs", + ) + parser.add_argument( + "--dump", + required=False, + help='Path to dump the results as text file, with one line for each wav: "wav transcription".', + ) args, unknown = parser.parse_known_args() # Reconstruct argv for absl.flags sys.argv = [sys.argv[0]] + unknown return args -if __name__ == '__main__': + +if __name__ == "__main__": create_flags() absl.app.run(partial(main, parse_args())) diff --git a/lm_optimizer.py b/lm_optimizer.py index 15fff5ed..669e4ba9 100644 --- a/lm_optimizer.py +++ b/lm_optimizer.py @@ -2,35 +2,39 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function +import sys + import absl.app import optuna -import sys -import tensorflow.compat.v1 as tfv1 - +from coqui_stt_ctcdecoder import Scorer from coqui_stt_training.evaluate import evaluate from coqui_stt_training.train import create_model from coqui_stt_training.util.config import Config, initialize_globals -from coqui_stt_training.util.flags import create_flags, FLAGS -from coqui_stt_training.util.logging import log_error from coqui_stt_training.util.evaluate_tools import wer_cer_batch -from coqui_stt_ctcdecoder import Scorer +from coqui_stt_training.util.flags import FLAGS, create_flags +from coqui_stt_training.util.logging import log_error + +import tensorflow.compat.v1 as tfv1 def character_based(): is_character_based = False if FLAGS.scorer_path: - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) + scorer = Scorer( + FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet + ) is_character_based = scorer.is_utf8_mode() return is_character_based -def objective(trial): - FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max) - FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max) - is_character_based = trial.study.user_attrs['is_character_based'] +def objective(trial): + FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max) + FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max) + + is_character_based = trial.study.user_attrs["is_character_based"] samples = [] - for step, test_file in enumerate(FLAGS.test_files.split(',')): + for step, test_file in enumerate(FLAGS.test_files.split(",")): tfv1.reset_default_graph() current_samples = evaluate([test_file], create_model) @@ -47,12 +51,15 @@ def objective(trial): wer, cer = wer_cer_batch(samples) return cer if is_character_based else wer + def main(_): initialize_globals() if not FLAGS.test_files: - log_error('You need to specify what files to use for evaluation via ' - 'the --test_files flag.') + log_error( + "You need to specify what files to use for evaluation via " + "the --test_files flag." + ) sys.exit(1) is_character_based = character_based() @@ -60,11 +67,15 @@ def main(_): study = optuna.create_study() study.set_user_attr("is_character_based", is_character_based) study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) - print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'], - study.best_params['lm_beta'], - study.best_value)) + print( + "Best params: lm_alpha={} and lm_beta={} with WER={}".format( + study.best_params["lm_alpha"], + study.best_params["lm_beta"], + study.best_value, + ) + ) -if __name__ == '__main__': +if __name__ == "__main__": create_flags() absl.app.run(main) diff --git a/native_client/CODINGSTYLE.md b/native_client/CODINGSTYLE.md index f0e4ec48..c6fa52f7 100644 --- a/native_client/CODINGSTYLE.md +++ b/native_client/CODINGSTYLE.md @@ -18,8 +18,8 @@ Variable naming File naming =========== -* Source code files should have a `.cc` prefix and headers a `.h` prefix, excluding - code important from elsewhere, which should follow local conventions, e.g. `.cpp` and `.h` +* Source code files should have a `.cc` prefix and headers a `.h` prefix, excluding + code important from elsewhere, which should follow local conventions, e.g. `.cpp` and `.h` in `ctcdecode/`. Doubts diff --git a/native_client/client.cc b/native_client/client.cc index 93afa555..1ac8af05 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -152,7 +152,7 @@ MetadataToJSON(Metadata* result) } } } - + out_string << "\n}\n"; return strdup(out_string.str().c_str()); diff --git a/native_client/ctcdecode/LICENSE.parlance b/native_client/ctcdecode/LICENSE.parlance index 1d113197..6efb31d4 100644 --- a/native_client/ctcdecode/LICENSE.parlance +++ b/native_client/ctcdecode/LICENSE.parlance @@ -20,4 +20,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index fc8f3255..93365c80 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -1,17 +1,18 @@ from __future__ import absolute_import, division, print_function -from . import swigwrapper # pylint: disable=import-self +from . import swigwrapper # pylint: disable=import-self # This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle # string encoding explicitly, here and throughout this file. -__version__ = swigwrapper.__version__.decode('utf-8') +__version__ = swigwrapper.__version__.decode("utf-8") # Hack: import error codes by matching on their names, as SWIG unfortunately # does not support binding enums to Python in a scoped manner yet. for symbol in dir(swigwrapper): - if symbol.startswith('STT_ERR_'): + if symbol.startswith("STT_ERR_"): globals()[symbol] = getattr(swigwrapper, symbol) + class Scorer(swigwrapper.Scorer): """Wrapper for Scorer. @@ -23,130 +24,140 @@ class Scorer(swigwrapper.Scorer): :alphabet: Alphabet :type scorer_path: basestring """ + def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None): super(Scorer, self).__init__() # Allow bare initialization if alphabet: - assert alpha is not None, 'alpha parameter is required' - assert beta is not None, 'beta parameter is required' - assert scorer_path, 'scorer_path parameter is required' + assert alpha is not None, "alpha parameter is required" + assert beta is not None, "beta parameter is required" + assert scorer_path, "scorer_path parameter is required" - err = self.init(scorer_path.encode('utf-8'), alphabet) + err = self.init(scorer_path.encode("utf-8"), alphabet) if err != 0: - raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err)) + raise ValueError( + "Scorer initialization failed with error code 0x{:X}".format(err) + ) self.reset_params(alpha, beta) class Alphabet(swigwrapper.Alphabet): """Convenience wrapper for Alphabet which calls init in the constructor""" + def __init__(self, config_path): super(Alphabet, self).__init__() - err = self.init(config_path.encode('utf-8')) + err = self.init(config_path.encode("utf-8")) if err != 0: - raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err)) + raise ValueError( + "Alphabet initialization failed with error code 0x{:X}".format(err) + ) def CanEncodeSingle(self, input): - ''' + """ Returns true if the single character/output class has a corresponding label in the alphabet. - ''' - return super(Alphabet, self).CanEncodeSingle(input.encode('utf-8')) + """ + return super(Alphabet, self).CanEncodeSingle(input.encode("utf-8")) def CanEncode(self, input): - ''' + """ Returns true if the entire string can be encoded into labels in this alphabet. - ''' - return super(Alphabet, self).CanEncode(input.encode('utf-8')) + """ + return super(Alphabet, self).CanEncode(input.encode("utf-8")) def EncodeSingle(self, input): - ''' + """ Encode a single character/output class into a label. Character must be in the alphabet, this method will assert that. Use `CanEncodeSingle` to test. - ''' - return super(Alphabet, self).EncodeSingle(input.encode('utf-8')) + """ + return super(Alphabet, self).EncodeSingle(input.encode("utf-8")) def Encode(self, input): - ''' + """ Encode a sequence of character/output classes into a sequence of labels. Characters are assumed to always take a single Unicode codepoint. Characters must be in the alphabet, this method will assert that. Use `CanEncode` and `CanEncodeSingle` to test. - ''' + """ # Convert SWIG's UnsignedIntVec to a Python list - res = super(Alphabet, self).Encode(input.encode('utf-8')) + res = super(Alphabet, self).Encode(input.encode("utf-8")) return [el for el in res] def DecodeSingle(self, input): res = super(Alphabet, self).DecodeSingle(input) - return res.decode('utf-8') + return res.decode("utf-8") def Decode(self, input): - '''Decode a sequence of labels into a string.''' + """Decode a sequence of labels into a string.""" res = super(Alphabet, self).Decode(input) - return res.decode('utf-8') + return res.decode("utf-8") class UTF8Alphabet(swigwrapper.UTF8Alphabet): """Convenience wrapper for Alphabet which calls init in the constructor""" + def __init__(self): super(UTF8Alphabet, self).__init__() - err = self.init(b'') + err = self.init(b"") if err != 0: - raise ValueError('UTF8Alphabet initialization failed with error code 0x{:X}'.format(err)) + raise ValueError( + "UTF8Alphabet initialization failed with error code 0x{:X}".format(err) + ) def CanEncodeSingle(self, input): - ''' + """ Returns true if the single character/output class has a corresponding label in the alphabet. - ''' - return super(UTF8Alphabet, self).CanEncodeSingle(input.encode('utf-8')) + """ + return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8")) def CanEncode(self, input): - ''' + """ Returns true if the entire string can be encoded into labels in this alphabet. - ''' - return super(UTF8Alphabet, self).CanEncode(input.encode('utf-8')) + """ + return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8")) def EncodeSingle(self, input): - ''' + """ Encode a single character/output class into a label. Character must be in the alphabet, this method will assert that. Use `CanEncodeSingle` to test. - ''' - return super(UTF8Alphabet, self).EncodeSingle(input.encode('utf-8')) + """ + return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8")) def Encode(self, input): - ''' + """ Encode a sequence of character/output classes into a sequence of labels. Characters are assumed to always take a single Unicode codepoint. Characters must be in the alphabet, this method will assert that. Use `CanEncode` and `CanEncodeSingle` to test. - ''' + """ # Convert SWIG's UnsignedIntVec to a Python list - res = super(UTF8Alphabet, self).Encode(input.encode('utf-8')) + res = super(UTF8Alphabet, self).Encode(input.encode("utf-8")) return [el for el in res] def DecodeSingle(self, input): res = super(UTF8Alphabet, self).DecodeSingle(input) - return res.decode('utf-8') + return res.decode("utf-8") def Decode(self, input): - '''Decode a sequence of labels into a string.''' + """Decode a sequence of labels into a string.""" res = super(UTF8Alphabet, self).Decode(input) - return res.decode('utf-8') + return res.decode("utf-8") - -def ctc_beam_search_decoder(probs_seq, - alphabet, - beam_size, - cutoff_prob=1.0, - cutoff_top_n=40, - scorer=None, - hot_words=dict(), - num_results=1): +def ctc_beam_search_decoder( + probs_seq, + alphabet, + beam_size, + cutoff_prob=1.0, + cutoff_top_n=40, + scorer=None, + hot_words=dict(), + num_results=1, +): """Wrapper for the CTC Beam Search Decoder. :param probs_seq: 2-D list of probability distributions over each time @@ -175,22 +186,33 @@ def ctc_beam_search_decoder(probs_seq, :rtype: list """ beam_results = swigwrapper.ctc_beam_search_decoder( - probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n, - scorer, hot_words, num_results) - beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] + probs_seq, + alphabet, + beam_size, + cutoff_prob, + cutoff_top_n, + scorer, + hot_words, + num_results, + ) + beam_results = [ + (res.confidence, alphabet.Decode(res.tokens)) for res in beam_results + ] return beam_results -def ctc_beam_search_decoder_batch(probs_seq, - seq_lengths, - alphabet, - beam_size, - num_processes, - cutoff_prob=1.0, - cutoff_top_n=40, - scorer=None, - hot_words=dict(), - num_results=1): +def ctc_beam_search_decoder_batch( + probs_seq, + seq_lengths, + alphabet, + beam_size, + num_processes, + cutoff_prob=1.0, + cutoff_top_n=40, + scorer=None, + hot_words=dict(), + num_results=1, +): """Wrapper for the batched CTC beam search decoder. :param probs_seq: 3-D list with each element as an instance of 2-D list @@ -222,7 +244,18 @@ def ctc_beam_search_decoder_batch(probs_seq, results, in descending order of the confidence. :rtype: list """ - batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results) + batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch( + probs_seq, + seq_lengths, + alphabet, + beam_size, + num_processes, + cutoff_prob, + cutoff_top_n, + scorer, + hot_words, + num_results, + ) batch_beam_results = [ [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] for beam_results in batch_beam_results diff --git a/native_client/ctcdecode/build_archive.py b/native_client/ctcdecode/build_archive.py index f94d1adf..a4f13c1a 100644 --- a/native_client/ctcdecode/build_archive.py +++ b/native_client/ctcdecode/build_archive.py @@ -6,84 +6,95 @@ import os import shlex import subprocess import sys - from multiprocessing.dummy import Pool -if sys.platform.startswith('win'): - ARGS = ['/nologo', '/D KENLM_MAX_ORDER=6', '/EHsc', '/source-charset:utf-8'] - OPT_ARGS = ['/O2', '/MT', '/D NDEBUG'] - DBG_ARGS = ['/Od', '/MTd', '/Zi', '/U NDEBUG', '/D DEBUG'] - OPENFST_DIR = 'third_party/openfst-1.6.9-win' +if sys.platform.startswith("win"): + ARGS = ["/nologo", "/D KENLM_MAX_ORDER=6", "/EHsc", "/source-charset:utf-8"] + OPT_ARGS = ["/O2", "/MT", "/D NDEBUG"] + DBG_ARGS = ["/Od", "/MTd", "/Zi", "/U NDEBUG", "/D DEBUG"] + OPENFST_DIR = "third_party/openfst-1.6.9-win" else: - ARGS = ['-fPIC', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-Wno-unused-local-typedefs', '-Wno-sign-compare'] - OPT_ARGS = ['-O3', '-DNDEBUG'] - DBG_ARGS = ['-O0', '-g', '-UNDEBUG', '-DDEBUG'] - OPENFST_DIR = 'third_party/openfst-1.6.7' - + ARGS = [ + "-fPIC", + "-DKENLM_MAX_ORDER=6", + "-std=c++11", + "-Wno-unused-local-typedefs", + "-Wno-sign-compare", + ] + OPT_ARGS = ["-O3", "-DNDEBUG"] + DBG_ARGS = ["-O0", "-g", "-UNDEBUG", "-DDEBUG"] + OPENFST_DIR = "third_party/openfst-1.6.7" INCLUDES = [ - '..', - '../kenlm', - OPENFST_DIR + '/src/include', - 'third_party/ThreadPool', - 'third_party/object_pool' + "..", + "../kenlm", + OPENFST_DIR + "/src/include", + "third_party/ThreadPool", + "third_party/object_pool", ] -KENLM_FILES = (glob.glob('../kenlm/util/*.cc') - + glob.glob('../kenlm/lm/*.cc') - + glob.glob('../kenlm/util/double-conversion/*.cc')) +KENLM_FILES = ( + glob.glob("../kenlm/util/*.cc") + + glob.glob("../kenlm/lm/*.cc") + + glob.glob("../kenlm/util/double-conversion/*.cc") +) -KENLM_FILES += glob.glob(OPENFST_DIR + '/src/lib/*.cc') +KENLM_FILES += glob.glob(OPENFST_DIR + "/src/lib/*.cc") KENLM_FILES = [ - fn for fn in KENLM_FILES - if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( - 'unittest.cc')) + fn + for fn in KENLM_FILES + if not ( + fn.endswith("main.cc") or fn.endswith("test.cc") or fn.endswith("unittest.cc") + ) ] CTC_DECODER_FILES = [ - 'ctc_beam_search_decoder.cpp', - 'scorer.cpp', - 'path_trie.cpp', - 'decoder_utils.cpp', - 'workspace_status.cc', - '../alphabet.cc', + "ctc_beam_search_decoder.cpp", + "scorer.cpp", + "path_trie.cpp", + "decoder_utils.cpp", + "workspace_status.cc", + "../alphabet.cc", ] -def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1): - compiler = os.environ.get('CXX', 'g++') - if sys.platform.startswith('win'): + +def build_archive( + srcs=[], out_name="", build_dir="temp_build/temp_build", debug=False, num_parallel=1 +): + compiler = os.environ.get("CXX", "g++") + if sys.platform.startswith("win"): compiler = '"{}"'.format(compiler) - ar = os.environ.get('AR', 'ar') - libexe = os.environ.get('LIBEXE', 'lib.exe') - libtool = os.environ.get('LIBTOOL', 'libtool') - cflags = os.environ.get('CFLAGS', '') + os.environ.get('CXXFLAGS', '') + ar = os.environ.get("AR", "ar") + libexe = os.environ.get("LIBEXE", "lib.exe") + libtool = os.environ.get("LIBTOOL", "libtool") + cflags = os.environ.get("CFLAGS", "") + os.environ.get("CXXFLAGS", "") args = ARGS + (DBG_ARGS if debug else OPT_ARGS) for file in srcs: - outfile = os.path.join(build_dir, os.path.splitext(file)[0] + '.o') + outfile = os.path.join(build_dir, os.path.splitext(file)[0] + ".o") outdir = os.path.dirname(outfile) if not os.path.exists(outdir): - print('mkdir', outdir) + print("mkdir", outdir) os.makedirs(outdir) def build_one(file): - outfile = os.path.join(build_dir, os.path.splitext(file)[0] + '.o') + outfile = os.path.join(build_dir, os.path.splitext(file)[0] + ".o") if os.path.exists(outfile): return - if sys.platform.startswith('win'): - file = '"{}"'.format(file.replace('\\', '/')) - output = '/Fo"{}"'.format(outfile.replace('\\', '/')) + if sys.platform.startswith("win"): + file = '"{}"'.format(file.replace("\\", "/")) + output = '/Fo"{}"'.format(outfile.replace("\\", "/")) else: - output = '-o ' + outfile + output = "-o " + outfile - cmd = '{cc} -c {cflags} {args} {includes} {infile} {output}'.format( + cmd = "{cc} -c {cflags} {args} {includes} {infile} {output}".format( cc=compiler, cflags=cflags, - args=' '.join(args), - includes=' '.join('-I' + i for i in INCLUDES), + args=" ".join(args), + includes=" ".join("-I" + i for i in INCLUDES), infile=file, output=output, ) @@ -94,30 +105,28 @@ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug pool = Pool(num_parallel) obj_files = list(pool.imap_unordered(build_one, srcs)) - if sys.platform.startswith('darwin'): - cmd = '{libtool} -static -o {outfile} {infiles}'.format( + if sys.platform.startswith("darwin"): + cmd = "{libtool} -static -o {outfile} {infiles}".format( libtool=libtool, outfile=out_name, - infiles=' '.join(obj_files), + infiles=" ".join(obj_files), ) print(cmd) subprocess.check_call(shlex.split(cmd)) - elif sys.platform.startswith('win'): + elif sys.platform.startswith("win"): cmd = '"{libexe}" /OUT:"{outfile}" {infiles} /MACHINE:X64 /NOLOGO'.format( - libexe=libexe, - outfile=out_name, - infiles=' '.join(obj_files)) - cmd = cmd.replace('\\', '/') + libexe=libexe, outfile=out_name, infiles=" ".join(obj_files) + ) + cmd = cmd.replace("\\", "/") print(cmd) subprocess.check_call(shlex.split(cmd)) else: - cmd = '{ar} rcs {outfile} {infiles}'.format( - ar=ar, - outfile=out_name, - infiles=' '.join(obj_files) + cmd = "{ar} rcs {outfile} {infiles}".format( + ar=ar, outfile=out_name, infiles=" ".join(obj_files) ) print(cmd) subprocess.check_call(shlex.split(cmd)) -if __name__ == '__main__': + +if __name__ == "__main__": build_common() diff --git a/native_client/ctcdecode/decoder_utils.cpp b/native_client/ctcdecode/decoder_utils.cpp index bb3e1c77..2f67034e 100644 --- a/native_client/ctcdecode/decoder_utils.cpp +++ b/native_client/ctcdecode/decoder_utils.cpp @@ -161,4 +161,4 @@ bool add_word_to_dictionary( add_word_to_fst(int_word, dictionary); return true; // return with successful adding -} \ No newline at end of file +} diff --git a/native_client/ctcdecode/numpy.i b/native_client/ctcdecode/numpy.i index 36bb55c9..72d5824d 100644 --- a/native_client/ctcdecode/numpy.i +++ b/native_client/ctcdecode/numpy.i @@ -545,7 +545,7 @@ const npy_intp *dims = array_dimensions(ary); for (i=0; i < nd; ++i) n_non_one += (dims[i] != 1) ? 1 : 0; - if (n_non_one > 1) + if (n_non_one > 1) array_clearflags(ary,NPY_ARRAY_CARRAY); array_enableflags(ary,NPY_ARRAY_FARRAY); /* Recompute the strides */ diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h index 93a09437..255c1897 100644 --- a/native_client/ctcdecode/path_trie.h +++ b/native_client/ctcdecode/path_trie.h @@ -93,8 +93,8 @@ public: unsigned int character; TimestepTreeNode* timesteps = nullptr; - // timestep temporary storage for each decoding step. - TimestepTreeNode* previous_timesteps = nullptr; + // timestep temporary storage for each decoding step. + TimestepTreeNode* previous_timesteps = nullptr; unsigned int new_timestep; PathTrie* parent; diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index e5c6c359..eb059fdf 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -1,10 +1,10 @@ #ifdef _MSC_VER #include #include - #include + #include #define R_OK 4 /* Read permission. */ - #define W_OK 2 /* Write permission. */ + #define W_OK 2 /* Write permission. */ #define F_OK 0 /* Existence. */ #define access _access diff --git a/native_client/ctcdecode/setup.cfg b/native_client/ctcdecode/setup.cfg index ebfde62b..07e10d80 100644 --- a/native_client/ctcdecode/setup.cfg +++ b/native_client/ctcdecode/setup.cfg @@ -13,4 +13,3 @@ bdist-dir=temp_build/temp_build [install_lib] build-dir=temp_build/temp_build - diff --git a/native_client/ctcdecode/setup.py b/native_client/ctcdecode/setup.py index 3f0c778d..6b987b7b 100644 --- a/native_client/ctcdecode/setup.py +++ b/native_client/ctcdecode/setup.py @@ -1,95 +1,105 @@ #!/usr/bin/env python from __future__ import absolute_import, division, print_function -from distutils.command.build import build -from setuptools import setup, Extension, distutils - import argparse import multiprocessing.pool import os import platform import sys +from distutils.command.build import build from build_archive import * +from setuptools import Extension, distutils, setup try: import numpy + try: numpy_include = numpy.get_include() except AttributeError: numpy_include = numpy.get_numpy_include() except ImportError: - numpy_include = '' - assert 'NUMPY_INCLUDE' in os.environ + numpy_include = "" + assert "NUMPY_INCLUDE" in os.environ -numpy_include = os.getenv('NUMPY_INCLUDE', numpy_include) -numpy_min_ver = os.getenv('NUMPY_DEP_VERSION', '') +numpy_include = os.getenv("NUMPY_INCLUDE", numpy_include) +numpy_min_ver = os.getenv("NUMPY_DEP_VERSION", "") parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_processes", default=1, type=int, - help="Number of cpu processes to build package. (default: %(default)d)") + help="Number of cpu processes to build package. (default: %(default)d)", +) known_args, unknown_args = parser.parse_known_args() -debug = '--debug' in unknown_args +debug = "--debug" in unknown_args # reconstruct sys.argv to pass to setup below sys.argv = [sys.argv[0]] + unknown_args + def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() + def maybe_rebuild(srcs, out_name, build_dir): if not os.path.exists(out_name): if not os.path.exists(build_dir): os.makedirs(build_dir) - build_archive(srcs=srcs, - out_name=out_name, - build_dir=build_dir, - num_parallel=known_args.num_processes, - debug=debug) + build_archive( + srcs=srcs, + out_name=out_name, + build_dir=build_dir, + num_parallel=known_args.num_processes, + debug=debug, + ) -project_version = read('../../training/coqui_stt_training/VERSION').strip() -build_dir = 'temp_build/temp_build' +project_version = read("../../training/coqui_stt_training/VERSION").strip() -if sys.platform.startswith('win'): - archive_ext = 'lib' +build_dir = "temp_build/temp_build" + +if sys.platform.startswith("win"): + archive_ext = "lib" else: - archive_ext = 'a' + archive_ext = "a" -third_party_build = 'third_party.{}'.format(archive_ext) -ctc_decoder_build = 'first_party.{}'.format(archive_ext) +third_party_build = "third_party.{}".format(archive_ext) +ctc_decoder_build = "first_party.{}".format(archive_ext) maybe_rebuild(KENLM_FILES, third_party_build, build_dir) maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir) decoder_module = Extension( - name='coqui_stt_ctcdecoder._swigwrapper', - sources=['swigwrapper.i'], - swig_opts=['-c++', '-extranative'], - language='c++', + name="coqui_stt_ctcdecoder._swigwrapper", + sources=["swigwrapper.i"], + swig_opts=["-c++", "-extranative"], + language="c++", include_dirs=INCLUDES + [numpy_include], extra_compile_args=ARGS + (DBG_ARGS if debug else OPT_ARGS), extra_link_args=[ctc_decoder_build, third_party_build], ) + class BuildExtFirst(build): - sub_commands = [('build_ext', build.has_ext_modules), - ('build_py', build.has_pure_modules), - ('build_clib', build.has_c_libraries), - ('build_scripts', build.has_scripts)] + sub_commands = [ + ("build_ext", build.has_ext_modules), + ("build_py", build.has_pure_modules), + ("build_clib", build.has_c_libraries), + ("build_scripts", build.has_scripts), + ] + setup( - name='coqui_stt_ctcdecoder', + name="coqui_stt_ctcdecoder", version=project_version, description="""DS CTC decoder""", - cmdclass = {'build': BuildExtFirst}, + cmdclass={"build": BuildExtFirst}, ext_modules=[decoder_module], - package_dir = {'coqui_stt_ctcdecoder': '.'}, - py_modules=['coqui_stt_ctcdecoder', 'coqui_stt_ctcdecoder.swigwrapper'], - install_requires = ['numpy%s' % numpy_min_ver], + package_dir={"coqui_stt_ctcdecoder": "."}, + py_modules=["coqui_stt_ctcdecoder", "coqui_stt_ctcdecoder.swigwrapper"], + install_requires=["numpy%s" % numpy_min_ver], ) diff --git a/native_client/dotnet/.gitignore b/native_client/dotnet/.gitignore index 3e759b75..13a9b2ed 100644 --- a/native_client/dotnet/.gitignore +++ b/native_client/dotnet/.gitignore @@ -221,7 +221,7 @@ ClientBin/ *.publishsettings orleans.codegen.cs -# Including strong name files can present a security risk +# Including strong name files can present a security risk # (https://github.com/github/gitignore/pull/2483#issue-259490424) #*.snk @@ -317,7 +317,7 @@ __pycache__/ # OpenCover UI analysis results OpenCover/ -# Azure Stream Analytics local run output +# Azure Stream Analytics local run output ASALocalRun/ # MSBuild Binary and Structured Log @@ -326,5 +326,5 @@ ASALocalRun/ # NVidia Nsight GPU debugger configuration file *.nvuser -# MFractors (Xamarin productivity tool) working folder +# MFractors (Xamarin productivity tool) working folder .mfractor/ diff --git a/native_client/dotnet/STTClient/Models/CandidateTranscript.cs b/native_client/dotnet/STTClient/Models/CandidateTranscript.cs index f158e2c2..d67f039c 100644 --- a/native_client/dotnet/STTClient/Models/CandidateTranscript.cs +++ b/native_client/dotnet/STTClient/Models/CandidateTranscript.cs @@ -14,4 +14,4 @@ /// public TokenMetadata[] Tokens { get; set; } } -} \ No newline at end of file +} diff --git a/native_client/dotnet/STTClient/Models/Metadata.cs b/native_client/dotnet/STTClient/Models/Metadata.cs index 537a22e8..de0334cd 100644 --- a/native_client/dotnet/STTClient/Models/Metadata.cs +++ b/native_client/dotnet/STTClient/Models/Metadata.cs @@ -10,4 +10,4 @@ /// public CandidateTranscript[] Transcripts { get; set; } } -} \ No newline at end of file +} diff --git a/native_client/dotnet/STTClient/Models/TokenMetadata.cs b/native_client/dotnet/STTClient/Models/TokenMetadata.cs index c5ef94d8..b13a83cb 100644 --- a/native_client/dotnet/STTClient/Models/TokenMetadata.cs +++ b/native_client/dotnet/STTClient/Models/TokenMetadata.cs @@ -18,4 +18,4 @@ /// public float StartTime; } -} \ No newline at end of file +} diff --git a/native_client/dotnet/STTClient/STTClient.csproj b/native_client/dotnet/STTClient/STTClient.csproj index 33a94115..9df1250c 100644 --- a/native_client/dotnet/STTClient/STTClient.csproj +++ b/native_client/dotnet/STTClient/STTClient.csproj @@ -35,7 +35,7 @@ - + $(DefineConstants);NO_HTTPS diff --git a/native_client/dotnet/STTConsole/App.config b/native_client/dotnet/STTConsole/App.config index b50c74f3..43a7a8ec 100644 --- a/native_client/dotnet/STTConsole/App.config +++ b/native_client/dotnet/STTConsole/App.config @@ -1,6 +1,6 @@  - + - \ No newline at end of file + diff --git a/native_client/dotnet/STTConsole/STTConsole.csproj b/native_client/dotnet/STTConsole/STTConsole.csproj index 54e11eb0..3aa52fb9 100644 --- a/native_client/dotnet/STTConsole/STTConsole.csproj +++ b/native_client/dotnet/STTConsole/STTConsole.csproj @@ -67,4 +67,4 @@ - \ No newline at end of file + diff --git a/native_client/dotnet/STTConsole/packages.config b/native_client/dotnet/STTConsole/packages.config index 17c1f326..98dec3bb 100644 --- a/native_client/dotnet/STTConsole/packages.config +++ b/native_client/dotnet/STTConsole/packages.config @@ -1,4 +1,4 @@  - \ No newline at end of file + diff --git a/native_client/dotnet/STTWPF/.gitignore b/native_client/dotnet/STTWPF/.gitignore index 3e759b75..13a9b2ed 100644 --- a/native_client/dotnet/STTWPF/.gitignore +++ b/native_client/dotnet/STTWPF/.gitignore @@ -221,7 +221,7 @@ ClientBin/ *.publishsettings orleans.codegen.cs -# Including strong name files can present a security risk +# Including strong name files can present a security risk # (https://github.com/github/gitignore/pull/2483#issue-259490424) #*.snk @@ -317,7 +317,7 @@ __pycache__/ # OpenCover UI analysis results OpenCover/ -# Azure Stream Analytics local run output +# Azure Stream Analytics local run output ASALocalRun/ # MSBuild Binary and Structured Log @@ -326,5 +326,5 @@ ASALocalRun/ # NVidia Nsight GPU debugger configuration file *.nvuser -# MFractors (Xamarin productivity tool) working folder +# MFractors (Xamarin productivity tool) working folder .mfractor/ diff --git a/native_client/dotnet/STTWPF/App.config b/native_client/dotnet/STTWPF/App.config index b50c74f3..43a7a8ec 100644 --- a/native_client/dotnet/STTWPF/App.config +++ b/native_client/dotnet/STTWPF/App.config @@ -1,6 +1,6 @@  - + - \ No newline at end of file + diff --git a/native_client/dotnet/STTWPF/Properties/Resources.Designer.cs b/native_client/dotnet/STTWPF/Properties/Resources.Designer.cs index 2478decd..edcd793d 100644 --- a/native_client/dotnet/STTWPF/Properties/Resources.Designer.cs +++ b/native_client/dotnet/STTWPF/Properties/Resources.Designer.cs @@ -10,8 +10,8 @@ namespace STT.WPF.Properties { using System; - - + + /// /// A strongly-typed resource class, for looking up localized strings, etc. /// @@ -23,15 +23,15 @@ namespace STT.WPF.Properties { [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] internal class Resources { - + private static global::System.Resources.ResourceManager resourceMan; - + private static global::System.Globalization.CultureInfo resourceCulture; - + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] internal Resources() { } - + /// /// Returns the cached ResourceManager instance used by this class. /// @@ -45,7 +45,7 @@ namespace STT.WPF.Properties { return resourceMan; } } - + /// /// Overrides the current thread's CurrentUICulture property for all /// resource lookups using this strongly typed resource class. diff --git a/native_client/dotnet/STTWPF/Properties/Resources.resx b/native_client/dotnet/STTWPF/Properties/Resources.resx index af7dbebb..ea9cbcdb 100644 --- a/native_client/dotnet/STTWPF/Properties/Resources.resx +++ b/native_client/dotnet/STTWPF/Properties/Resources.resx @@ -1,17 +1,17 @@  - @@ -114,4 +114,4 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - \ No newline at end of file + diff --git a/native_client/dotnet/STTWPF/Properties/Settings.Designer.cs b/native_client/dotnet/STTWPF/Properties/Settings.Designer.cs index de63d157..6b67445f 100644 --- a/native_client/dotnet/STTWPF/Properties/Settings.Designer.cs +++ b/native_client/dotnet/STTWPF/Properties/Settings.Designer.cs @@ -9,14 +9,14 @@ //------------------------------------------------------------------------------ namespace STT.WPF.Properties { - - + + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.Editors.SettingsDesigner.SettingsSingleFileGenerator", "15.9.0.0")] internal sealed partial class Settings : global::System.Configuration.ApplicationSettingsBase { - + private static Settings defaultInstance = ((Settings)(global::System.Configuration.ApplicationSettingsBase.Synchronized(new Settings()))); - + public static Settings Default { get { return defaultInstance; diff --git a/native_client/dotnet/STTWPF/Properties/Settings.settings b/native_client/dotnet/STTWPF/Properties/Settings.settings index 033d7a5e..c2dbd5ca 100644 --- a/native_client/dotnet/STTWPF/Properties/Settings.settings +++ b/native_client/dotnet/STTWPF/Properties/Settings.settings @@ -4,4 +4,4 @@ - \ No newline at end of file + diff --git a/native_client/dotnet/STTWPF/ViewModels/MainWindowViewModel.cs b/native_client/dotnet/STTWPF/ViewModels/MainWindowViewModel.cs index 0ed4822b..309fbff9 100644 --- a/native_client/dotnet/STTWPF/ViewModels/MainWindowViewModel.cs +++ b/native_client/dotnet/STTWPF/ViewModels/MainWindowViewModel.cs @@ -131,7 +131,7 @@ namespace STT.WPF.ViewModels public MMDevice SelectedDevice { get => _selectedDevice; - set => SetProperty(ref _selectedDevice, value, + set => SetProperty(ref _selectedDevice, value, onChanged: UpdateSelectedDevice); } @@ -255,7 +255,7 @@ namespace STT.WPF.ViewModels private void LoadAvailableCaptureDevices() { AvailableRecordDevices = new ObservableCollection( - MMDeviceEnumerator.EnumerateDevices(DataFlow.All, DeviceState.Active)); //we get only enabled devices + MMDeviceEnumerator.EnumerateDevices(DataFlow.All, DeviceState.Active)); //we get only enabled devices EnableStartRecord = true; if (AvailableRecordDevices?.Count != 0) SelectedDevice = AvailableRecordDevices[0]; @@ -282,14 +282,14 @@ namespace STT.WPF.ViewModels .ToWaveSource(16); //bits per sample _convertedSource = _convertedSource.ToMono(); - } + } } private void Capture_DataAvailable(object sender, DataAvailableEventArgs e) { //read data from the converedSource //important: don't use the e.Data here - //the e.Data contains the raw data provided by the + //the e.Data contains the raw data provided by the //soundInSource which won't have the STT required audio format byte[] buffer = new byte[_convertedSource.WaveFormat.BytesPerSecond / 2]; @@ -319,7 +319,7 @@ namespace STT.WPF.ViewModels } } } - + /// /// Enables the external scorer. /// @@ -422,4 +422,4 @@ namespace STT.WPF.ViewModels } } } -} \ No newline at end of file +} diff --git a/native_client/dotnet/STTWPF/packages.config b/native_client/dotnet/STTWPF/packages.config index ececb977..0744ebce 100644 --- a/native_client/dotnet/STTWPF/packages.config +++ b/native_client/dotnet/STTWPF/packages.config @@ -6,4 +6,4 @@ - \ No newline at end of file + diff --git a/native_client/dotnet/nupkg/build/STT.targets b/native_client/dotnet/nupkg/build/STT.targets index deebff48..12bdcb3d 100644 --- a/native_client/dotnet/nupkg/build/STT.targets +++ b/native_client/dotnet/nupkg/build/STT.targets @@ -6,4 +6,4 @@ PreserveNewest - \ No newline at end of file + diff --git a/native_client/getopt_win.h b/native_client/getopt_win.h index 0cb88895..6bb908de 100644 --- a/native_client/getopt_win.h +++ b/native_client/getopt_win.h @@ -3,9 +3,9 @@ * DISCLAIMER * This file is part of the mingw-w64 runtime package. * - * The mingw-w64 runtime package and its code is distributed in the hope that it - * will be useful but WITHOUT ANY WARRANTY. ALL WARRANTIES, EXPRESSED OR - * IMPLIED ARE HEREBY DISCLAIMED. This includes but is not limited to + * The mingw-w64 runtime package and its code is distributed in the hope that it + * will be useful but WITHOUT ANY WARRANTY. ALL WARRANTIES, EXPRESSED OR + * IMPLIED ARE HEREBY DISCLAIMED. This includes but is not limited to * warranties of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. */ /* diff --git a/native_client/java/.idea/codeStyles/Project.xml b/native_client/java/.idea/codeStyles/Project.xml index 30aa626c..f898aa8f 100644 --- a/native_client/java/.idea/codeStyles/Project.xml +++ b/native_client/java/.idea/codeStyles/Project.xml @@ -26,4 +26,4 @@ - \ No newline at end of file + diff --git a/native_client/java/.idea/gradle.xml b/native_client/java/.idea/gradle.xml index 6e2e5ce6..1d5dd329 100644 --- a/native_client/java/.idea/gradle.xml +++ b/native_client/java/.idea/gradle.xml @@ -16,4 +16,4 @@ - \ No newline at end of file + diff --git a/native_client/java/.idea/misc.xml b/native_client/java/.idea/misc.xml index b0c7b20c..1a0fbc34 100644 --- a/native_client/java/.idea/misc.xml +++ b/native_client/java/.idea/misc.xml @@ -35,4 +35,4 @@ - \ No newline at end of file + diff --git a/native_client/java/.idea/runConfigurations.xml b/native_client/java/.idea/runConfigurations.xml index 7f68460d..9b6e38d6 100644 --- a/native_client/java/.idea/runConfigurations.xml +++ b/native_client/java/.idea/runConfigurations.xml @@ -9,4 +9,4 @@ - \ No newline at end of file + diff --git a/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml index eca70cfe..6b78462d 100644 --- a/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml +++ b/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml index eca70cfe..6b78462d 100644 --- a/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml +++ b/native_client/java/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/native_client/java/app/src/test/java/ai/coqui/sttexampleapp/ExampleUnitTest.java b/native_client/java/app/src/test/java/ai/coqui/sttexampleapp/ExampleUnitTest.java index be5d2766..a2b67ba8 100644 --- a/native_client/java/app/src/test/java/ai/coqui/sttexampleapp/ExampleUnitTest.java +++ b/native_client/java/app/src/test/java/ai/coqui/sttexampleapp/ExampleUnitTest.java @@ -14,4 +14,4 @@ public class ExampleUnitTest { public void addition_isCorrect() { assertEquals(4, 2 + 2); } -} \ No newline at end of file +} diff --git a/native_client/java/gradle.properties b/native_client/java/gradle.properties index 82618cec..743d692c 100644 --- a/native_client/java/gradle.properties +++ b/native_client/java/gradle.properties @@ -11,5 +11,3 @@ org.gradle.jvmargs=-Xmx1536m # This option should only be used with decoupled projects. More details, visit # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects # org.gradle.parallel=true - - diff --git a/native_client/java/jni/stt.i b/native_client/java/jni/stt.i index f3b2d134..9c194059 100644 --- a/native_client/java/jni/stt.i +++ b/native_client/java/jni/stt.i @@ -20,7 +20,7 @@ %extend struct CandidateTranscript { /** * Retrieve one TokenMetadata element - * + * * @param i Array index of the TokenMetadata to get * * @return The TokenMetadata requested or null @@ -33,7 +33,7 @@ %extend struct Metadata { /** * Retrieve one CandidateTranscript element - * + * * @param i Array index of the CandidateTranscript to get * * @return The CandidateTranscript requested or null diff --git a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/CandidateTranscript.java b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/CandidateTranscript.java index d04d4e12..44463f1d 100644 --- a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/CandidateTranscript.java +++ b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/CandidateTranscript.java @@ -36,7 +36,7 @@ public class CandidateTranscript { } /** - * Size of the tokens array + * Size of the tokens array */ public long getNumTokens() { return implJNI.CandidateTranscript_NumTokens_get(swigCPtr, this); diff --git a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/Metadata.java b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/Metadata.java index 3961b91b..f3439fed 100644 --- a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/Metadata.java +++ b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/Metadata.java @@ -40,7 +40,7 @@ public class Metadata { } /** - * Size of the transcripts array + * Size of the transcripts array */ public long getNumTranscripts() { return implJNI.Metadata_NumTranscripts_get(swigCPtr, this); diff --git a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/STT_Error_Codes.java b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/STT_Error_Codes.java index 656e9337..861e2d4c 100644 --- a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/STT_Error_Codes.java +++ b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/STT_Error_Codes.java @@ -70,4 +70,3 @@ public enum STT_Error_Codes { private static int next = 0; } } - diff --git a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/TokenMetadata.java b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/TokenMetadata.java index d2517b28..0f1e45f5 100644 --- a/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/TokenMetadata.java +++ b/native_client/java/libstt/src/main/java/ai/coqui/libstt_doc/TokenMetadata.java @@ -35,21 +35,21 @@ public class TokenMetadata { } /** - * The text corresponding to this token + * The text corresponding to this token */ public String getText() { return implJNI.TokenMetadata_Text_get(swigCPtr, this); } /** - * Position of the token in units of 20ms + * Position of the token in units of 20ms */ public long getTimestep() { return implJNI.TokenMetadata_Timestep_get(swigCPtr, this); } /** - * Position of the token in seconds + * Position of the token in seconds */ public float getStartTime() { return implJNI.TokenMetadata_StartTime_get(swigCPtr, this); diff --git a/native_client/java/libstt/src/test/java/ai/coqui/libstt/ExampleUnitTest.java b/native_client/java/libstt/src/test/java/ai/coqui/libstt/ExampleUnitTest.java index 64262910..43c83b2e 100644 --- a/native_client/java/libstt/src/test/java/ai/coqui/libstt/ExampleUnitTest.java +++ b/native_client/java/libstt/src/test/java/ai/coqui/libstt/ExampleUnitTest.java @@ -14,4 +14,4 @@ public class ExampleUnitTest { public void addition_isCorrect() { assertEquals(4, 2 + 2); } -} \ No newline at end of file +} diff --git a/native_client/javascript/Makefile b/native_client/javascript/Makefile index 70f7686f..d71e12a0 100644 --- a/native_client/javascript/Makefile +++ b/native_client/javascript/Makefile @@ -1,5 +1,5 @@ NODE_BUILD_TOOL ?= node-pre-gyp -NODE_ABI_TARGET ?= +NODE_ABI_TARGET ?= NODE_BUILD_VERBOSE ?= --verbose NPM_TOOL ?= npm PROJECT_NAME ?= stt diff --git a/native_client/javascript/binding.gyp b/native_client/javascript/binding.gyp index 594f1f39..21124cc8 100644 --- a/native_client/javascript/binding.gyp +++ b/native_client/javascript/binding.gyp @@ -1,46 +1,44 @@ { - "targets": [ - { - "target_name": "stt", - "sources": [ "stt_wrap.cxx" ], - "libraries": [ - "$(LIBS)" - ], - "include_dirs": [ - "../" - ], - "conditions": [ - [ "OS=='mac'", { - "xcode_settings": { - "OTHER_CXXFLAGS": [ - "-stdlib=libc++", - "-mmacosx-version-min=10.10" - ], - "OTHER_LDFLAGS": [ - "-stdlib=libc++", - "-mmacosx-version-min=10.10" - ] - } - } - ] - ] - }, - { - "target_name": "action_after_build", - "type": "none", - "dependencies": [ "<(module_name)" ], - "copies": [ + "targets": [ { - "files": [ "<(PRODUCT_DIR)/<(module_name).node" ], - "destination": "<(module_path)" - } - ] - } - ], - "variables": { - "build_v8_with_gn": 0, - "v8_enable_pointer_compression": 0, - "v8_enable_31bit_smis_on_64bit_arch": 0, - "enable_lto": 1 - }, + "target_name": "stt", + "sources": ["stt_wrap.cxx"], + "libraries": ["$(LIBS)"], + "include_dirs": ["../"], + "conditions": [ + [ + "OS=='mac'", + { + "xcode_settings": { + "OTHER_CXXFLAGS": [ + "-stdlib=libc++", + "-mmacosx-version-min=10.10", + ], + "OTHER_LDFLAGS": [ + "-stdlib=libc++", + "-mmacosx-version-min=10.10", + ], + } + }, + ] + ], + }, + { + "target_name": "action_after_build", + "type": "none", + "dependencies": ["<(module_name)"], + "copies": [ + { + "files": ["<(PRODUCT_DIR)/<(module_name).node"], + "destination": "<(module_path)", + } + ], + }, + ], + "variables": { + "build_v8_with_gn": 0, + "v8_enable_pointer_compression": 0, + "v8_enable_31bit_smis_on_64bit_arch": 0, + "enable_lto": 1, + }, } diff --git a/native_client/javascript/index.ts b/native_client/javascript/index.ts index 9ad915fb..be18a8ad 100644 --- a/native_client/javascript/index.ts +++ b/native_client/javascript/index.ts @@ -136,7 +136,7 @@ class StreamImpl { } /** * Exposes the type of Stream without actually exposing the class. - * Because the Stream class should not be instantiated directly, + * Because the Stream class should not be instantiated directly, * but instead be created via :js:func:`Model.createStream`. */ export type Stream = StreamImpl; diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index 84fc5ed6..831dd2a4 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -37,7 +37,7 @@ ModelState::decode(const DecoderState& state) const } Metadata* -ModelState::decode_metadata(const DecoderState& state, +ModelState::decode_metadata(const DecoderState& state, size_t num_results) { vector out = state.decode(num_results); diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index 49b5b3d3..e1d37a70 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -1,27 +1,28 @@ import os import platform -#The API is not snake case which triggers linter errors -#pylint: disable=invalid-name +# The API is not snake case which triggers linter errors +# pylint: disable=invalid-name if platform.system().lower() == "windows": - dslib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'lib') + dslib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "lib") # On Windows, we can't rely on RPATH being set to $ORIGIN/lib/ or on # @loader_path/lib - if hasattr(os, 'add_dll_directory'): + if hasattr(os, "add_dll_directory"): # Starting with Python 3.8 this properly handles the problem os.add_dll_directory(dslib_path) else: # Before Pythin 3.8 we need to change the PATH to include the proper # directory for the dynamic linker - os.environ['PATH'] = dslib_path + ';' + os.environ['PATH'] + os.environ["PATH"] = dslib_path + ";" + os.environ["PATH"] import stt # rename for backwards compatibility from stt.impl import Version as version + class Model(object): """ Class holding a Coqui STT model @@ -29,13 +30,18 @@ class Model(object): :param aModelPath: Path to model file to load :type aModelPath: str """ + def __init__(self, model_path): # make sure the attribute is there if CreateModel fails self._impl = None status, impl = stt.impl.CreateModel(model_path) if status != 0: - raise RuntimeError("CreateModel failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) + raise RuntimeError( + "CreateModel failed with '{}' (0x{:X})".format( + stt.impl.ErrorCodeToErrorMessage(status), status + ) + ) self._impl = impl def __del__(self): @@ -85,7 +91,11 @@ class Model(object): """ status = stt.impl.EnableExternalScorer(self._impl, scorer_path) if status != 0: - raise RuntimeError("EnableExternalScorer failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) + raise RuntimeError( + "EnableExternalScorer failed with '{}' (0x{:X})".format( + stt.impl.ErrorCodeToErrorMessage(status), status + ) + ) def disableExternalScorer(self): """ @@ -98,7 +108,7 @@ class Model(object): def addHotWord(self, word, boost): """ Add a word and its boost for decoding. - + Words that don't occur in the scorer (e.g. proper nouns) or strings that contain spaces won't be taken into account. :param word: the hot-word @@ -111,7 +121,11 @@ class Model(object): """ status = stt.impl.AddHotWord(self._impl, word, boost) if status != 0: - raise RuntimeError("AddHotWord failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) + raise RuntimeError( + "AddHotWord failed with '{}' (0x{:X})".format( + stt.impl.ErrorCodeToErrorMessage(status), status + ) + ) def eraseHotWord(self, word): """ @@ -124,7 +138,11 @@ class Model(object): """ status = stt.impl.EraseHotWord(self._impl, word) if status != 0: - raise RuntimeError("EraseHotWord failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) + raise RuntimeError( + "EraseHotWord failed with '{}' (0x{:X})".format( + stt.impl.ErrorCodeToErrorMessage(status), status + ) + ) def clearHotWords(self): """ @@ -134,7 +152,11 @@ class Model(object): """ status = stt.impl.ClearHotWords(self._impl) if status != 0: - raise RuntimeError("ClearHotWords failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) + raise RuntimeError( + "ClearHotWords failed with '{}' (0x{:X})".format( + stt.impl.ErrorCodeToErrorMessage(status), status + ) + ) def setScorerAlphaBeta(self, alpha, beta): """ @@ -190,7 +212,11 @@ class Model(object): """ status, ctx = stt.impl.CreateStream(self._impl) if status != 0: - raise RuntimeError("CreateStream failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) + raise RuntimeError( + "CreateStream failed with '{}' (0x{:X})".format( + stt.impl.ErrorCodeToErrorMessage(status), status + ) + ) return Stream(ctx) @@ -199,6 +225,7 @@ class Stream(object): Class wrapping a stt stream. The constructor cannot be called directly. Use :func:`Model.createStream()` """ + def __init__(self, native_stream): self._impl = native_stream @@ -216,7 +243,9 @@ class Stream(object): :throws: RuntimeError if the stream object is not valid """ if not self._impl: - raise RuntimeError("Stream object is not valid. Trying to feed an already finished stream?") + raise RuntimeError( + "Stream object is not valid. Trying to feed an already finished stream?" + ) stt.impl.FeedAudioContent(self._impl, audio_buffer) def intermediateDecode(self): @@ -229,7 +258,9 @@ class Stream(object): :throws: RuntimeError if the stream object is not valid """ if not self._impl: - raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") + raise RuntimeError( + "Stream object is not valid. Trying to decode an already finished stream?" + ) return stt.impl.IntermediateDecode(self._impl) def intermediateDecodeWithMetadata(self, num_results=1): @@ -245,7 +276,9 @@ class Stream(object): :throws: RuntimeError if the stream object is not valid """ if not self._impl: - raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") + raise RuntimeError( + "Stream object is not valid. Trying to decode an already finished stream?" + ) return stt.impl.IntermediateDecodeWithMetadata(self._impl, num_results) def finishStream(self): @@ -260,7 +293,9 @@ class Stream(object): :throws: RuntimeError if the stream object is not valid """ if not self._impl: - raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") + raise RuntimeError( + "Stream object is not valid. Trying to finish an already finished stream?" + ) result = stt.impl.FinishStream(self._impl) self._impl = None return result @@ -281,7 +316,9 @@ class Stream(object): :throws: RuntimeError if the stream object is not valid """ if not self._impl: - raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") + raise RuntimeError( + "Stream object is not valid. Trying to finish an already finished stream?" + ) result = stt.impl.FinishStreamWithMetadata(self._impl, num_results) self._impl = None return result @@ -294,7 +331,9 @@ class Stream(object): :throws: RuntimeError if the stream object is not valid """ if not self._impl: - raise RuntimeError("Stream object is not valid. Trying to free an already finished stream?") + raise RuntimeError( + "Stream object is not valid. Trying to free an already finished stream?" + ) stt.impl.FreeStream(self._impl) self._impl = None @@ -311,13 +350,11 @@ class TokenMetadata(object): The text for this token """ - def timestep(self): """ Position of the token in units of 20ms """ - def start_time(self): """ Position of the token in seconds @@ -328,6 +365,7 @@ class CandidateTranscript(object): """ Stores the entire CTC output as an array of character metadata objects """ + def tokens(self): """ List of tokens @@ -336,7 +374,6 @@ class CandidateTranscript(object): :type: list """ - def confidence(self): """ Approximated confidence value for this transcription. This is roughly the diff --git a/native_client/python/client.py b/native_client/python/client.py index 85290c76..a648e7f2 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -3,16 +3,16 @@ from __future__ import absolute_import, division, print_function import argparse -import numpy as np +import json import shlex import subprocess import sys import wave -import json - -from stt import Model, version from timeit import default_timer as timer +import numpy as np +from stt import Model, version + try: from shhlex import quote except ImportError: @@ -20,19 +20,26 @@ except ImportError: def convert_samplerate(audio_path, desired_sample_rate): - sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate) + sox_cmd = "sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - ".format( + quote(audio_path), desired_sample_rate + ) try: output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE) except subprocess.CalledProcessError as e: - raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr)) + raise RuntimeError("SoX returned non-zero status: {}".format(e.stderr)) except OSError as e: - raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(desired_sample_rate, e.strerror)) + raise OSError( + e.errno, + "SoX not found, use {}hz files or install it: {}".format( + desired_sample_rate, e.strerror + ), + ) return desired_sample_rate, np.frombuffer(output, np.int16) def metadata_to_string(metadata): - return ''.join(token.text for token in metadata.tokens) + return "".join(token.text for token in metadata.tokens) def words_from_candidate_transcript(metadata): @@ -70,56 +77,78 @@ def words_from_candidate_transcript(metadata): def metadata_json_output(metadata): json_result = dict() - json_result["transcripts"] = [{ - "confidence": transcript.confidence, - "words": words_from_candidate_transcript(transcript), - } for transcript in metadata.transcripts] + json_result["transcripts"] = [ + { + "confidence": transcript.confidence, + "words": words_from_candidate_transcript(transcript), + } + for transcript in metadata.transcripts + ] return json.dumps(json_result, indent=2) - class VersionAction(argparse.Action): def __init__(self, *args, **kwargs): super(VersionAction, self).__init__(nargs=0, *args, **kwargs) def __call__(self, *args, **kwargs): - print('Coqui STT ', version()) + print("Coqui STT ", version()) exit(0) def main(): - parser = argparse.ArgumentParser(description='Running Coqui STT inference.') - parser.add_argument('--model', required=True, - help='Path to the model (protocol buffer binary file)') - parser.add_argument('--scorer', required=False, - help='Path to the external scorer file') - parser.add_argument('--audio', required=True, - help='Path to the audio file to run (WAV format)') - parser.add_argument('--beam_width', type=int, - help='Beam width for the CTC decoder') - parser.add_argument('--lm_alpha', type=float, - help='Language model weight (lm_alpha). If not specified, use default from the scorer package.') - parser.add_argument('--lm_beta', type=float, - help='Word insertion bonus (lm_beta). If not specified, use default from the scorer package.') - parser.add_argument('--version', action=VersionAction, - help='Print version and exits') - parser.add_argument('--extended', required=False, action='store_true', - help='Output string from extended metadata') - parser.add_argument('--json', required=False, action='store_true', - help='Output json from metadata with timestamp of each word') - parser.add_argument('--candidate_transcripts', type=int, default=3, - help='Number of candidate transcripts to include in JSON output') - parser.add_argument('--hot_words', type=str, - help='Hot-words and their boosts.') + parser = argparse.ArgumentParser(description="Running Coqui STT inference.") + parser.add_argument( + "--model", required=True, help="Path to the model (protocol buffer binary file)" + ) + parser.add_argument( + "--scorer", required=False, help="Path to the external scorer file" + ) + parser.add_argument( + "--audio", required=True, help="Path to the audio file to run (WAV format)" + ) + parser.add_argument("--beam_width", type=int, help="Beam width for the CTC decoder") + parser.add_argument( + "--lm_alpha", + type=float, + help="Language model weight (lm_alpha). If not specified, use default from the scorer package.", + ) + parser.add_argument( + "--lm_beta", + type=float, + help="Word insertion bonus (lm_beta). If not specified, use default from the scorer package.", + ) + parser.add_argument( + "--version", action=VersionAction, help="Print version and exits" + ) + parser.add_argument( + "--extended", + required=False, + action="store_true", + help="Output string from extended metadata", + ) + parser.add_argument( + "--json", + required=False, + action="store_true", + help="Output json from metadata with timestamp of each word", + ) + parser.add_argument( + "--candidate_transcripts", + type=int, + default=3, + help="Number of candidate transcripts to include in JSON output", + ) + parser.add_argument("--hot_words", type=str, help="Hot-words and their boosts.") args = parser.parse_args() - print('Loading model from file {}'.format(args.model), file=sys.stderr) + print("Loading model from file {}".format(args.model), file=sys.stderr) model_load_start = timer() # sphinx-doc: python_ref_model_start ds = Model(args.model) # sphinx-doc: python_ref_model_stop model_load_end = timer() - model_load_start - print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr) + print("Loaded model in {:.3}s.".format(model_load_end), file=sys.stderr) if args.beam_width: ds.setBeamWidth(args.beam_width) @@ -127,44 +156,55 @@ def main(): desired_sample_rate = ds.sampleRate() if args.scorer: - print('Loading scorer from files {}'.format(args.scorer), file=sys.stderr) + print("Loading scorer from files {}".format(args.scorer), file=sys.stderr) scorer_load_start = timer() ds.enableExternalScorer(args.scorer) scorer_load_end = timer() - scorer_load_start - print('Loaded scorer in {:.3}s.'.format(scorer_load_end), file=sys.stderr) + print("Loaded scorer in {:.3}s.".format(scorer_load_end), file=sys.stderr) if args.lm_alpha and args.lm_beta: ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta) if args.hot_words: - print('Adding hot-words', file=sys.stderr) - for word_boost in args.hot_words.split(','): - word,boost = word_boost.split(':') - ds.addHotWord(word,float(boost)) + print("Adding hot-words", file=sys.stderr) + for word_boost in args.hot_words.split(","): + word, boost = word_boost.split(":") + ds.addHotWord(word, float(boost)) - fin = wave.open(args.audio, 'rb') + fin = wave.open(args.audio, "rb") fs_orig = fin.getframerate() if fs_orig != desired_sample_rate: - print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs_orig, desired_sample_rate), file=sys.stderr) + print( + "Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.".format( + fs_orig, desired_sample_rate + ), + file=sys.stderr, + ) fs_new, audio = convert_samplerate(args.audio, desired_sample_rate) else: audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) - audio_length = fin.getnframes() * (1/fs_orig) + audio_length = fin.getnframes() * (1 / fs_orig) fin.close() - print('Running inference.', file=sys.stderr) + print("Running inference.", file=sys.stderr) inference_start = timer() # sphinx-doc: python_ref_inference_start if args.extended: print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0])) elif args.json: - print(metadata_json_output(ds.sttWithMetadata(audio, args.candidate_transcripts))) + print( + metadata_json_output(ds.sttWithMetadata(audio, args.candidate_transcripts)) + ) else: print(ds.stt(audio)) # sphinx-doc: python_ref_inference_stop inference_end = timer() - inference_start - print('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length), file=sys.stderr) + print( + "Inference took %0.3fs for %0.3fs audio file." % (inference_end, audio_length), + file=sys.stderr, + ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/native_client/python/setup.py b/native_client/python/setup.py index 2a2ebd63..d83b3204 100755 --- a/native_client/python/setup.py +++ b/native_client/python/setup.py @@ -1,107 +1,122 @@ #! /usr/bin/env python -from setuptools import setup, Extension -from distutils.command.build import build - import os import subprocess import sys +from distutils.command.build import build + +from setuptools import Extension, setup + def main(): try: import numpy + try: numpy_include = numpy.get_include() except AttributeError: numpy_include = numpy.get_numpy_include() except ImportError: - numpy_include = '' - assert 'NUMPY_INCLUDE' in os.environ + numpy_include = "" + assert "NUMPY_INCLUDE" in os.environ def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() - numpy_include = os.getenv('NUMPY_INCLUDE', numpy_include) - numpy_min_ver = os.getenv('NUMPY_DEP_VERSION', '') + numpy_include = os.getenv("NUMPY_INCLUDE", numpy_include) + numpy_min_ver = os.getenv("NUMPY_DEP_VERSION", "") - project_name = 'STT' - if '--project_name' in sys.argv: - project_name_idx = sys.argv.index('--project_name') + project_name = "STT" + if "--project_name" in sys.argv: + project_name_idx = sys.argv.index("--project_name") project_name = sys.argv[project_name_idx + 1] - sys.argv.remove('--project_name') + sys.argv.remove("--project_name") sys.argv.pop(project_name_idx) - with open('../../training/coqui_stt_training/VERSION', 'r') as ver: + with open("../../training/coqui_stt_training/VERSION", "r") as ver: project_version = ver.read().strip() class BuildExtFirst(build): - sub_commands = [('build_ext', build.has_ext_modules), - ('build_py', build.has_pure_modules), - ('build_clib', build.has_c_libraries), - ('build_scripts', build.has_scripts)] + sub_commands = [ + ("build_ext", build.has_ext_modules), + ("build_py", build.has_pure_modules), + ("build_clib", build.has_c_libraries), + ("build_scripts", build.has_scripts), + ] # Properly pass arguments for linking, setuptools will perform some checks def lib_dirs_split(a): - if os.name == 'posix': - return a.split('-L')[1:] + if os.name == "posix": + return a.split("-L")[1:] - if os.name == 'nt': + if os.name == "nt": return [] - raise AssertionError('os.name == java not expected') + raise AssertionError("os.name == java not expected") def libs_split(a): - if os.name == 'posix': - return a.split('-l')[1:] + if os.name == "posix": + return a.split("-l")[1:] - if os.name == 'nt': - return a.split('.lib')[0:1] + if os.name == "nt": + return a.split(".lib")[0:1] - raise AssertionError('os.name == java not expected') + raise AssertionError("os.name == java not expected") - ds_ext = Extension(name='stt._impl', - sources=['impl.i'], - include_dirs=[numpy_include, '../'], - library_dirs=list(map(lambda x: x.strip(), lib_dirs_split(os.getenv('MODEL_LDFLAGS', '')))), - libraries=list(map(lambda x: x.strip(), libs_split(os.getenv('MODEL_LIBS', '')))), - swig_opts=['-c++', '-keyword']) + ds_ext = Extension( + name="stt._impl", + sources=["impl.i"], + include_dirs=[numpy_include, "../"], + library_dirs=list( + map(lambda x: x.strip(), lib_dirs_split(os.getenv("MODEL_LDFLAGS", ""))) + ), + libraries=list( + map(lambda x: x.strip(), libs_split(os.getenv("MODEL_LIBS", ""))) + ), + swig_opts=["-c++", "-keyword"], + ) - setup(name=project_name, - description='A library for doing speech recognition using a Coqui STT model', - long_description=read('README.rst'), - long_description_content_type='text/x-rst; charset=UTF-8', - author='Coqui GmbH', - version=project_version, - package_dir={'stt': '.'}, - cmdclass={'build': BuildExtFirst}, - license='MPL-2.0', - url='https://github.com/coqui-ai/STT', - project_urls={ - 'Documentation': 'https://stt.readthedocs.io', - 'Tracker': 'https://github.com/coqui-ai/STT/issues', - 'Repository': 'https://github.com/coqui-ai/STT/tree/v{}'.format(project_version), - 'Discussions': 'https://github.com/coqui-ai/STT/discussions', - }, - ext_modules=[ds_ext], - py_modules=['stt', 'stt.client', 'stt.impl'], - entry_points={'console_scripts':['stt=stt.client:main']}, - install_requires=['numpy%s' % numpy_min_ver], - include_package_data=True, - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Topic :: Multimedia :: Sound/Audio :: Speech', - 'Topic :: Scientific/Engineering :: Human Machine Interfaces', - 'Topic :: Scientific/Engineering', - 'Topic :: Utilities', - ]) + setup( + name=project_name, + description="A library for doing speech recognition using a Coqui STT model", + long_description=read("README.rst"), + long_description_content_type="text/x-rst; charset=UTF-8", + author="Coqui GmbH", + version=project_version, + package_dir={"stt": "."}, + cmdclass={"build": BuildExtFirst}, + license="MPL-2.0", + url="https://github.com/coqui-ai/STT", + project_urls={ + "Documentation": "https://stt.readthedocs.io", + "Tracker": "https://github.com/coqui-ai/STT/issues", + "Repository": "https://github.com/coqui-ai/STT/tree/v{}".format( + project_version + ), + "Discussions": "https://github.com/coqui-ai/STT/discussions", + }, + ext_modules=[ds_ext], + py_modules=["stt", "stt.client", "stt.impl"], + entry_points={"console_scripts": ["stt=stt.client:main"]}, + install_requires=["numpy%s" % numpy_min_ver], + include_package_data=True, + classifiers=[ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Scientific/Engineering :: Human Machine Interfaces", + "Topic :: Scientific/Engineering", + "Topic :: Utilities", + ], + ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/native_client/stt.cc b/native_client/stt.cc index 80c6335e..28715ec5 100644 --- a/native_client/stt.cc +++ b/native_client/stt.cc @@ -474,7 +474,7 @@ STT_FinishStream(StreamingState* aSctx) } Metadata* -STT_FinishStreamWithMetadata(StreamingState* aSctx, +STT_FinishStreamWithMetadata(StreamingState* aSctx, unsigned int aNumResults) { Metadata* result = aSctx->finishStreamWithMetadata(aNumResults); diff --git a/native_client/swift/stt_ios/stt_ios.h b/native_client/swift/stt_ios/stt_ios.h index 972d5c2b..0f35bd04 100644 --- a/native_client/swift/stt_ios/stt_ios.h +++ b/native_client/swift/stt_ios/stt_ios.h @@ -9,5 +9,3 @@ #import // In this header, you should import all the public headers of your framework using statements like #import - - diff --git a/native_client/swift/stt_ios/stt_ios.modulemap b/native_client/swift/stt_ios/stt_ios.modulemap index b5a1fc51..47af9452 100644 --- a/native_client/swift/stt_ios/stt_ios.modulemap +++ b/native_client/swift/stt_ios/stt_ios.modulemap @@ -3,7 +3,7 @@ framework module stt_ios { export * module * { export * } - + explicit module libstt_Private { header "coqui-stt.h" export * diff --git a/native_client/swift/stt_ios_test/ContentView.swift b/native_client/swift/stt_ios_test/ContentView.swift index 7a45c7b5..a1892960 100644 --- a/native_client/swift/stt_ios_test/ContentView.swift +++ b/native_client/swift/stt_ios_test/ContentView.swift @@ -11,7 +11,7 @@ import SwiftUI struct ContentView: View { private var stt = SpeechRecognitionImpl() @State var isRecognizingMicrophone = false - + var body: some View { VStack { Text("Coqui STT iOS Demo") @@ -28,16 +28,16 @@ struct ContentView: View { .padding(30) } } - + func recognizeFiles() { self.stt.recognizeFiles() } - + func startMicRecognition() { isRecognizingMicrophone = true self.stt.startMicrophoneRecognition() } - + func stopMicRecognition() { isRecognizingMicrophone = false self.stt.stopMicrophoneRecognition() diff --git a/native_client/swift/stt_ios_test/SceneDelegate.swift b/native_client/swift/stt_ios_test/SceneDelegate.swift index d0cb01eb..e25708b9 100644 --- a/native_client/swift/stt_ios_test/SceneDelegate.swift +++ b/native_client/swift/stt_ios_test/SceneDelegate.swift @@ -62,4 +62,3 @@ class SceneDelegate: UIResponder, UIWindowSceneDelegate { } - diff --git a/native_client/swift/stt_ios_test/SpeechRecognitionImpl.swift b/native_client/swift/stt_ios_test/SpeechRecognitionImpl.swift index f754f122..fef3dc5c 100644 --- a/native_client/swift/stt_ios_test/SpeechRecognitionImpl.swift +++ b/native_client/swift/stt_ios_test/SpeechRecognitionImpl.swift @@ -21,58 +21,58 @@ struct FillComplexInputParm { class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDelegate { private var model: STTModel private var stream: STTStream? - + private var captureSession = AVCaptureSession() private var audioData = Data() - + override init() { let modelPath = Bundle.main.path(forResource: "coqui-stt-0.9.3-models", ofType: "tflite")! let scorerPath = Bundle.main.path(forResource: "coqui-stt-0.9.3-models", ofType: "scorer")! model = try! STTModel(modelPath: modelPath) try! model.enableExternalScorer(scorerPath: scorerPath) - + super.init() - + // prepare audio capture self.configureCaptureSession() } - + // MARK: Microphone recognition - + private func configureCaptureSession() { captureSession.beginConfiguration() - + let audioDevice = AVCaptureDevice.default(.builtInMicrophone, for: .audio, position: .unspecified) - + let audioDeviceInput = try! AVCaptureDeviceInput(device: audioDevice!) guard captureSession.canAddInput(audioDeviceInput) else { return } captureSession.addInput(audioDeviceInput) - + let serialQueue = DispatchQueue(label: "serialQueue") let audioOutput = AVCaptureAudioDataOutput() audioOutput.setSampleBufferDelegate(self, queue: serialQueue) - + guard captureSession.canAddOutput(audioOutput) else { return } captureSession.sessionPreset = .inputPriority captureSession.addOutput(audioOutput) captureSession.commitConfiguration() } - + func captureOutput(_ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) { var sourceFormat = (sampleBuffer.formatDescription?.audioFormatList[0].mASBD)! var destinationFormat = sourceFormat destinationFormat.mSampleRate = 16000.0 - + var audioConverterRef: AudioConverterRef? let createConverterStatus = AudioConverterNew(&sourceFormat, &destinationFormat, &audioConverterRef) - + if (createConverterStatus != noErr) { print("Error creating converter") } - + var quality = kAudioConverterQuality_Max - + AudioConverterSetProperty(audioConverterRef!, kAudioConverterSampleRateConverterQuality, UInt32(MemoryLayout.size), &quality) let blockBuffer = CMSampleBufferGetDataBuffer(sampleBuffer) @@ -80,15 +80,15 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele var pcmLength: Int = 0 var pcmData: UnsafeMutablePointer? let status: OSStatus = CMBlockBufferGetDataPointer(blockBuffer!, atOffset: 0, lengthAtOffsetOut: nil, totalLengthOut: &pcmLength, dataPointerOut: &pcmData) - + if status != noErr { print("Error getting something") } else { var input = FillComplexInputParm(source: pcmData!, sourceSize: UInt32(pcmLength)) - + let outputBuffer = malloc(pcmLength) memset(outputBuffer, 0, pcmLength); - + var outputBufferList = AudioBufferList() outputBufferList.mNumberBuffers = 1 outputBufferList.mBuffers.mData = outputBuffer @@ -103,56 +103,56 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele inUserData: UnsafeMutableRawPointer? ) -> OSStatus { var inputPtr = inUserData!.load(as: FillComplexInputParm.self) - + if (inputPtr.sourceSize <= 0) { ioNumberDataPacket.pointee = 1 return -1 } - + let rawPtr = UnsafeMutableRawPointer(inputPtr.source) - + ioData.pointee.mNumberBuffers = 1 ioData.pointee.mBuffers.mData = rawPtr ioData.pointee.mBuffers.mDataByteSize = inputPtr.sourceSize ioData.pointee.mBuffers.mNumberChannels = 1 - + ioNumberDataPacket.pointee = (inputPtr.sourceSize / 2) inputPtr.sourceSize = 0 - + return noErr }; - + var packetSize: UInt32 = UInt32(pcmLength / 2) - + let status: OSStatus = AudioConverterFillComplexBuffer(audioConverterRef!, inputDataProc, &input, &packetSize, &outputBufferList, nil) - + if (status != noErr) { print("Error: " + status.description) } else { let data = outputBufferList.mBuffers.mData! let byteSize = outputBufferList.mBuffers.mDataByteSize - + let shorts = UnsafeBufferPointer(start: data.assumingMemoryBound(to: Int16.self), count: Int(byteSize / 2)) stream!.feedAudioContent(buffer: shorts) - + // save bytes to audio data for creating a pcm file later for the captured audio let ptr = UnsafePointer(data.assumingMemoryBound(to: UInt8.self)) audioData.append(ptr, count: Int(byteSize)) } - + free(outputBuffer) AudioConverterDispose(audioConverterRef!) } } - - + + public func startMicrophoneRecognition() { audioData = Data() stream = try! model.createStream() captureSession.startRunning() print("Started listening...") } - + private func writeAudioDataToPCMFile() { let documents = NSSearchPathForDirectoriesInDomains(FileManager.SearchPathDirectory.documentDirectory, FileManager.SearchPathDomainMask.userDomainMask, true)[0] let filePath = documents + "/recording.pcm" @@ -160,19 +160,19 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele try! audioData.write(to: url) print("Saved audio to " + filePath) } - + public func stopMicrophoneRecognition() { captureSession.stopRunning() - + let result = stream?.finishStream() print("Result: " + result!) - + // optional, useful for checking the recorded audio writeAudioDataToPCMFile() } - + // MARK: Audio file recognition - + private func render(audioContext: AudioContext?, stream: STTStream) { guard let audioContext = audioContext else { fatalError("Couldn't create the audioContext") @@ -239,7 +239,7 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele fatalError("Couldn't read the audio file") } } - + private func recognizeFile(audioPath: String, completion: @escaping () -> ()) { let url = URL(fileURLWithPath: audioPath) @@ -257,7 +257,7 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele completion() }) } - + public func recognizeFiles() { // Add file names (without extension) here if you want to test recognition from files. // Remember to add them to the project under Copy Bundle Resources. @@ -266,7 +266,7 @@ class SpeechRecognitionImpl : NSObject, AVCaptureAudioDataOutputSampleBufferDele let serialQueue = DispatchQueue(label: "serialQueue") let group = DispatchGroup() group.enter() - + if let first = files.first { serialQueue.async { self.recognizeFile(audioPath: Bundle.main.path(forResource: first, ofType: "wav")!) { diff --git a/native_client/test/concurrent_streams.py b/native_client/test/concurrent_streams.py index 8bce5807..e5146b97 100644 --- a/native_client/test/concurrent_streams.py +++ b/native_client/test/concurrent_streams.py @@ -3,22 +3,26 @@ from __future__ import absolute_import, division, print_function import argparse -import numpy as np import wave +import numpy as np from stt import Model def main(): - parser = argparse.ArgumentParser(description='Running STT inference.') - parser.add_argument('--model', required=True, - help='Path to the model (protocol buffer binary file)') - parser.add_argument('--scorer', nargs='?', - help='Path to the external scorer file') - parser.add_argument('--audio1', required=True, - help='First audio file to use in interleaved streams') - parser.add_argument('--audio2', required=True, - help='Second audio file to use in interleaved streams') + parser = argparse.ArgumentParser(description="Running STT inference.") + parser.add_argument( + "--model", required=True, help="Path to the model (protocol buffer binary file)" + ) + parser.add_argument("--scorer", nargs="?", help="Path to the external scorer file") + parser.add_argument( + "--audio1", required=True, help="First audio file to use in interleaved streams" + ) + parser.add_argument( + "--audio2", + required=True, + help="Second audio file to use in interleaved streams", + ) args = parser.parse_args() ds = Model(args.model) @@ -26,12 +30,12 @@ def main(): if args.scorer: ds.enableExternalScorer(args.scorer) - fin = wave.open(args.audio1, 'rb') + fin = wave.open(args.audio1, "rb") fs1 = fin.getframerate() audio1 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) fin.close() - fin = wave.open(args.audio2, 'rb') + fin = wave.open(args.audio2, "rb") fs2 = fin.getframerate() audio2 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) fin.close() @@ -49,5 +53,6 @@ def main(): print(stream1.finishStream()) print(stream2.finishStream()) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/native_client/xldd b/native_client/xldd index ef2ca20f..a9353884 100755 --- a/native_client/xldd +++ b/native_client/xldd @@ -6,7 +6,7 @@ # crosstool-ng 1.22.0 # In order to use it, copy it in same directory than other -# toolchain binaries and rename it with same tuple. +# toolchain binaries and rename it with same tuple. # (i.e. /opt/arm-sysmic-linux-gnueabihf/bin/arm-sysmic-linux-gnueabihf-ldd) # Thus, this will automaticaly detect necessary information # about your toolchain. diff --git a/parse_valgrind_suppressions.sh b/parse_valgrind_suppressions.sh index 1a3769ed..681b8211 100755 --- a/parse_valgrind_suppressions.sh +++ b/parse_valgrind_suppressions.sh @@ -13,7 +13,7 @@ BEGIN { suppression=0; md5sum = "md5sum" } # If the line begins with '{', it's the start of a supression; so set the var and initialise things /^{/ { - suppression=1; i=0; next + suppression=1; i=0; next } # If the line begins with '}' its the end of a suppression /^}/ { @@ -26,7 +26,7 @@ BEGIN { suppression=0; md5sum = "md5sum" } } # Otherwise, it's a normal line. If we're inside a supression, store it, and pipe it to md5sum. Otherwise it's cruft, so ignore it { if (suppression) - { + { supparray[++i] = $0 print |& md5sum } @@ -35,7 +35,7 @@ BEGIN { suppression=0; md5sum = "md5sum" } function ProcessInput() { - # Pipe the result from md5sum, then close it + # Pipe the result from md5sum, then close it md5sum |& getline result close(md5sum) # gawk can't cope with enormous ints like $result would be, so stringify it first by prefixing a definite string @@ -49,9 +49,9 @@ BEGIN { suppression=0; md5sum = "md5sum" } function OutputSuppression() { - # A suppression is surrounded by '{' and '}'. Its data was stored line by line in the array - print "{" + # A suppression is surrounded by '{' and '}'. Its data was stored line by line in the array + print "{" for (n=1; n <= i; ++n) { print supparray[n] } - print "}" + print "}" } diff --git a/setup.py b/setup.py index 565a657f..b9bec546 100644 --- a/setup.py +++ b/setup.py @@ -8,77 +8,74 @@ from setuptools import find_packages, setup def main(): - version_file = Path(__file__).parent / 'VERSION' + version_file = Path(__file__).parent / "VERSION" with open(str(version_file)) as fin: version = fin.read().strip() install_requires_base = [ - 'absl-py', - 'attrdict', - 'bs4', - 'numpy', - 'optuna', - 'opuslib == 2.0.0', - 'pandas', - 'progressbar2', - 'pyogg >= 0.6.14a1', - 'pyxdg', - 'resampy >= 0.2.2', - 'requests', - 'semver', - 'six', - 'sox', - 'soundfile', + "absl-py", + "attrdict", + "bs4", + "numpy", + "optuna", + "opuslib == 2.0.0", + "pandas", + "progressbar2", + "pyogg >= 0.6.14a1", + "pyxdg", + "resampy >= 0.2.2", + "requests", + "semver", + "six", + "sox", + "soundfile", ] - decoder_pypi_dep = [ - 'coqui_stt_ctcdecoder == {}'.format(version) - ] + decoder_pypi_dep = ["coqui_stt_ctcdecoder == {}".format(version)] - tensorflow_pypi_dep = [ - 'tensorflow == 1.15.4' - ] + tensorflow_pypi_dep = ["tensorflow == 1.15.4"] - if os.environ.get('DS_NODECODER', ''): + if os.environ.get("DS_NODECODER", ""): install_requires = install_requires_base else: install_requires = install_requires_base + decoder_pypi_dep - if os.environ.get('DS_NOTENSORFLOW', ''): + if os.environ.get("DS_NOTENSORFLOW", ""): install_requires = install_requires else: install_requires = install_requires + tensorflow_pypi_dep setup( - name='coqui_stt_training', + name="coqui_stt_training", version=version, - description='Training code for Coqui STT', - url='https://github.com/coqui-ai/STT', - author='Coqui STT authors', - license='MPL-2.0', + description="Training code for Coqui STT", + url="https://github.com/coqui-ai/STT", + author="Coqui STT authors", + license="MPL-2.0", # Classifiers help users find your project by categorizing it. # # For a list of valid classifiers, see https://pypi.org/classifiers/ classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Multimedia :: Sound/Audio :: Speech', - 'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)', - 'Programming Language :: Python :: 3', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", + "Programming Language :: Python :: 3", ], - package_dir={'': 'training'}, - packages=find_packages(where='training'), - python_requires='>=3.5, <4', + package_dir={"": "training"}, + packages=find_packages(where="training"), + python_requires=">=3.5, <4", install_requires=install_requires, # If there are data files included in your packages that need to be # installed, specify them here. package_data={ - 'coqui_stt_training': [ - 'VERSION', - 'GRAPH_VERSION', + "coqui_stt_training": [ + "VERSION", + "GRAPH_VERSION", ], }, ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/stats.py b/stats.py index dbd8ecd3..8ff0d67e 100644 --- a/stats.py +++ b/stats.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 import argparse import functools -import pandas - -from coqui_stt_training.util.helpers import secs_to_hours from pathlib import Path +import pandas +from coqui_stt_training.util.helpers import secs_to_hours + def read_csvs(csv_files): # Relative paths are relative to CSV location @@ -17,32 +17,59 @@ def read_csvs(csv_files): sets = [] for csv in csv_files: - file = pandas.read_csv(csv, encoding='utf-8', na_filter=False) - file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv)) + file = pandas.read_csv(csv, encoding="utf-8", na_filter=False) + file["wav_filename"] = file["wav_filename"].apply( + functools.partial(absolutify, csv) + ) sets.append(file) # Concat all sets, drop any extra columns, re-index the final result as 0..N - return pandas.concat(sets, join='inner', ignore_index=True) + return pandas.concat(sets, join="inner", ignore_index=True) def main(): parser = argparse.ArgumentParser() - parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True) - parser.add_argument("--sample-rate", type=int, default=16000, required=False, help="Audio sample rate") - parser.add_argument("--channels", type=int, default=1, required=False, help="Audio channels") - parser.add_argument("--bits-per-sample", type=int, default=16, required=False, help="Audio bits per sample") + parser.add_argument( + "-csv", + "--csv-files", + help="Str. Filenames as a comma separated list", + required=True, + ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + required=False, + help="Audio sample rate", + ) + parser.add_argument( + "--channels", type=int, default=1, required=False, help="Audio channels" + ) + parser.add_argument( + "--bits-per-sample", + type=int, + default=16, + required=False, + help="Audio bits per sample", + ) args = parser.parse_args() in_files = [Path(i).absolute() for i in args.csv_files.split(",")] csv_dataframe = read_csvs(in_files) - total_bytes = csv_dataframe['wav_filesize'].sum() + total_bytes = csv_dataframe["wav_filesize"].sum() total_files = len(csv_dataframe) - total_seconds = ((csv_dataframe['wav_filesize'] - 44) / args.sample_rate / args.channels / (args.bits_per_sample // 8)).sum() + total_seconds = ( + (csv_dataframe["wav_filesize"] - 44) + / args.sample_rate + / args.channels + / (args.bits_per_sample // 8) + ).sum() - print('Total bytes:', total_bytes) - print('Total files:', total_files) - print('Total time:', secs_to_hours(total_seconds)) + print("Total bytes:", total_bytes) + print("Total files:", total_files) + print("Total time:", secs_to_hours(total_seconds)) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/tests/test_data/alphabet_windows.txt b/tests/test_data/alphabet_windows.txt index 61b1b24d..b5eff572 100644 --- a/tests/test_data/alphabet_windows.txt +++ b/tests/test_data/alphabet_windows.txt @@ -1,4 +1,3 @@ a b c - diff --git a/tests/test_importers.py b/tests/test_importers.py index a897d0fd..53607d8e 100644 --- a/tests/test_importers.py +++ b/tests/test_importers.py @@ -1,42 +1,49 @@ import unittest - from argparse import Namespace -from coqui_stt_training.util.importers import validate_label_eng, get_validate_label from pathlib import Path +from coqui_stt_training.util.importers import get_validate_label, validate_label_eng + + def from_here(path): here = Path(__file__) return here.parent / path + class TestValidateLabelEng(unittest.TestCase): def test_numbers(self): label = validate_label_eng("this is a 1 2 3 test") self.assertEqual(label, None) -class TestGetValidateLabel(unittest.TestCase): +class TestGetValidateLabel(unittest.TestCase): def test_no_validate_label_locale(self): f = get_validate_label(Namespace()) - self.assertEqual(f('toto'), 'toto') - self.assertEqual(f('toto1234'), None) - self.assertEqual(f('toto1234[{[{[]'), None) + self.assertEqual(f("toto"), "toto") + self.assertEqual(f("toto1234"), None) + self.assertEqual(f("toto1234[{[{[]"), None) def test_validate_label_locale_default(self): f = get_validate_label(Namespace(validate_label_locale=None)) - self.assertEqual(f('toto'), 'toto') - self.assertEqual(f('toto1234'), None) - self.assertEqual(f('toto1234[{[{[]'), None) + self.assertEqual(f("toto"), "toto") + self.assertEqual(f("toto1234"), None) + self.assertEqual(f("toto1234[{[{[]"), None) def test_get_validate_label_missing(self): - args = Namespace(validate_label_locale=from_here('test_data/validate_locale_ger.py')) + args = Namespace( + validate_label_locale=from_here("test_data/validate_locale_ger.py") + ) f = get_validate_label(args) self.assertEqual(f, None) def test_get_validate_label(self): - args = Namespace(validate_label_locale=from_here('test_data/validate_locale_fra.py')) + args = Namespace( + validate_label_locale=from_here("test_data/validate_locale_fra.py") + ) f = get_validate_label(args) - l = f('toto') - self.assertEqual(l, 'toto') + l = f("toto") + self.assertEqual(l, "toto") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_text.py b/tests/test_text.py index ae6a18e0..6cc72940 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -1,13 +1,13 @@ -import unittest import os +import unittest from coqui_stt_ctcdecoder import Alphabet -class TestAlphabetParsing(unittest.TestCase): +class TestAlphabetParsing(unittest.TestCase): def _ending_tester(self, file, expected): - alphabet = Alphabet(os.path.join(os.path.dirname(__file__), 'test_data', file)) - label = '' + alphabet = Alphabet(os.path.join(os.path.dirname(__file__), "test_data", file)) + label = "" label_id = -1 for expected_label, expected_label_id in expected: try: @@ -22,13 +22,14 @@ class TestAlphabetParsing(unittest.TestCase): self.assertEqual(label, expected_label) def test_macos_ending(self): - self._ending_tester('alphabet_macos.txt', [('a', 0), ('b', 1), ('c', 2)]) + self._ending_tester("alphabet_macos.txt", [("a", 0), ("b", 1), ("c", 2)]) def test_unix_ending(self): - self._ending_tester('alphabet_unix.txt', [('a', 0), ('b', 1), ('c', 2)]) + self._ending_tester("alphabet_unix.txt", [("a", 0), ("b", 1), ("c", 2)]) def test_windows_ending(self): - self._ending_tester('alphabet_windows.txt', [('a', 0), ('b', 1), ('c', 2)]) + self._ending_tester("alphabet_windows.txt", [("a", 0), ("b", 1), ("c", 2)]) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_value_range.py b/tests/test_value_range.py index d10c0029..dc9b4991 100644 --- a/tests/test_value_range.py +++ b/tests/test_value_range.py @@ -1,27 +1,32 @@ import unittest import numpy as np +from coqui_stt_training.util.helpers import ( + ValueRange, + get_value_range, + pick_value_from_range, + tf_pick_value_from_range, +) + import tensorflow as tf -from coqui_stt_training.util.helpers import ValueRange, get_value_range, pick_value_from_range, tf_pick_value_from_range class TestValueRange(unittest.TestCase): - def _ending_tester(self, value, value_type, expected): result = get_value_range(value, value_type) self.assertEqual(result, expected) def test_int_str_scalar(self): - self._ending_tester('1', int, ValueRange(1, 1, 0)) + self._ending_tester("1", int, ValueRange(1, 1, 0)) def test_int_str_scalar_radius(self): - self._ending_tester('1~3', int, ValueRange(1, 1, 3)) + self._ending_tester("1~3", int, ValueRange(1, 1, 3)) def test_int_str_range(self): - self._ending_tester('1:2', int, ValueRange(1, 2, 0)) + self._ending_tester("1:2", int, ValueRange(1, 2, 0)) def test_int_str_range_radius(self): - self._ending_tester('1:2~3', int, ValueRange(1, 2, 3)) + self._ending_tester("1:2~3", int, ValueRange(1, 2, 3)) def test_int_scalar(self): self._ending_tester(1, int, ValueRange(1, 1, 0)) @@ -33,16 +38,16 @@ class TestValueRange(unittest.TestCase): self._ending_tester((1, 2, 3), int, ValueRange(1, 2, 3)) def test_float_str_scalar(self): - self._ending_tester('1.0', float, ValueRange(1.0, 1.0, 0.0)) + self._ending_tester("1.0", float, ValueRange(1.0, 1.0, 0.0)) def test_float_str_scalar_radius(self): - self._ending_tester('1.0~3.0', float, ValueRange(1.0, 1.0, 3.0)) + self._ending_tester("1.0~3.0", float, ValueRange(1.0, 1.0, 3.0)) def test_float_str_range(self): - self._ending_tester('1.0:2.0', float, ValueRange(1.0, 2.0, 0.0)) + self._ending_tester("1.0:2.0", float, ValueRange(1.0, 2.0, 0.0)) def test_float_str_range_radius(self): - self._ending_tester('1.0:2.0~3.0', float, ValueRange(1.0, 2.0, 3.0)) + self._ending_tester("1.0:2.0~3.0", float, ValueRange(1.0, 2.0, 3.0)) def test_float_scalar(self): self._ending_tester(1.0, float, ValueRange(1.0, 1.0, 0.0)) @@ -61,7 +66,7 @@ class TestPickValueFromFixedRange(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestPickValueFromFixedRange, self).__init__(*args, **kwargs) self.session = tf.Session() - self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock') + self.clock_ph = tf.placeholder(dtype=tf.float64, name="clock") def _ending_tester(self, value_range, clock, expected): with tf.Session() as session: @@ -71,7 +76,10 @@ class TestPickValueFromFixedRange(unittest.TestCase): return session.run(tf_pick, feed_dict={self.clock_ph: c}) is_int = isinstance(value_range.start, int) - for pick, int_type, float_type in [(pick_value_from_range, int, float), (run_pick, np.int32, np.float32)]: + for pick, int_type, float_type in [ + (pick_value_from_range, int, float), + (run_pick, np.int32, np.float32), + ]: result = pick(value_range, clock) self.assertEqual(result, expected) self.assertTrue(isinstance(result, int_type if is_int else float_type)) @@ -99,9 +107,11 @@ class TestPickValueFromRandomizedRange(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestPickValueFromRandomizedRange, self).__init__(*args, **kwargs) self.session = tf.Session() - self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock') + self.clock_ph = tf.placeholder(dtype=tf.float64, name="clock") - def _ending_tester(self, value_range, clock_min, clock_max, expected_min, expected_max): + def _ending_tester( + self, value_range, clock_min, clock_max, expected_min, expected_max + ): with self.session as session: tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph) @@ -109,12 +119,26 @@ class TestPickValueFromRandomizedRange(unittest.TestCase): return session.run(tf_pick, feed_dict={self.clock_ph: c}) is_int = isinstance(value_range.start, int) - clock_range = np.arange(clock_min, clock_max, (clock_max - clock_min) / 100.0) - for pick, int_type, float_type in [(pick_value_from_range, int, float), (run_pick, np.int32, np.float32)]: + clock_range = np.arange( + clock_min, clock_max, (clock_max - clock_min) / 100.0 + ) + for pick, int_type, float_type in [ + (pick_value_from_range, int, float), + (run_pick, np.int32, np.float32), + ]: results = [pick(value_range, c) for c in clock_range] self.assertGreater(len(set(results)), 80) - self.assertTrue(all(map(lambda x: expected_min <= x <= expected_max, results))) - self.assertTrue(all(map(lambda x: isinstance(x, int_type if is_int else float_type), results))) + self.assertTrue( + all(map(lambda x: expected_min <= x <= expected_max, results)) + ) + self.assertTrue( + all( + map( + lambda x: isinstance(x, int_type if is_int else float_type), + results, + ) + ) + ) def test_int_0(self): self._ending_tester(ValueRange(10000, 30000, 10000), 0.0, 0.1, 0, 22000) @@ -126,14 +150,20 @@ class TestPickValueFromRandomizedRange(unittest.TestCase): self._ending_tester(ValueRange(10000, 30000, 10000), 0.8, 1.0, 16000, 40000) def test_float_0(self): - self._ending_tester(ValueRange(10000.0, 30000.0, 10000.0), 0.0, 0.1, 0.0, 22000.0) + self._ending_tester( + ValueRange(10000.0, 30000.0, 10000.0), 0.0, 0.1, 0.0, 22000.0 + ) def test_float_half(self): - self._ending_tester(ValueRange(10000.0, 30000.0, 10000.0), 0.4, 0.6, 8000.0, 32000.0) + self._ending_tester( + ValueRange(10000.0, 30000.0, 10000.0), 0.4, 0.6, 8000.0, 32000.0 + ) def test_float_1(self): - self._ending_tester(ValueRange(10000.0, 30000.0, 10000.0), 0.8, 1.0, 16000.0, 40000.0) + self._ending_tester( + ValueRange(10000.0, 30000.0, 10000.0), 0.8, 1.0, 16000.0, 40000.0 + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/train.py b/train.py index caf4e1d4..6fd3bb7a 100755 --- a/train.py +++ b/train.py @@ -2,11 +2,11 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function -if __name__ == '__main__': +if __name__ == "__main__": try: from coqui_stt_training import train as ds_train except ImportError: - print('Training package is not installed. See training documentation.') + print("Training package is not installed. See training documentation.") raise ds_train.run_script() diff --git a/training/coqui_stt_training/evaluate.py b/training/coqui_stt_training/evaluate.py index 5298f3af..bf6ea914 100755 --- a/training/coqui_stt_training/evaluate.py +++ b/training/coqui_stt_training/evaluate.py @@ -4,33 +4,36 @@ from __future__ import absolute_import, division, print_function import json import sys - from multiprocessing import cpu_count import absl.app import progressbar +from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch +from six.moves import zip + import tensorflow as tf import tensorflow.compat.v1 as tfv1 -from coqui_stt_ctcdecoder import ctc_beam_search_decoder_batch, Scorer -from six.moves import zip from .util.augmentations import NormalizeSampleRate -from .util.config import Config, initialize_globals from .util.checkpoints import load_graph_for_evaluation +from .util.config import Config, initialize_globals from .util.evaluate_tools import calculate_and_print_report, save_samples_json from .util.feeding import create_dataset -from .util.flags import create_flags, FLAGS +from .util.flags import FLAGS, create_flags from .util.helpers import check_ctcdecoder_version from .util.logging import create_progressbar, log_error, log_progress check_ctcdecoder_version() + def sparse_tensor_value_to_texts(value, alphabet): r""" Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings representing its values, converting tokens to strings using ``alphabet``. """ - return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet) + return sparse_tuple_to_texts( + (value.indices, value.values, value.dense_shape), alphabet + ) def sparse_tuple_to_texts(sp_tuple, alphabet): @@ -45,36 +48,42 @@ def sparse_tuple_to_texts(sp_tuple, alphabet): def evaluate(test_csvs, create_model): if FLAGS.scorer_path: - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, - FLAGS.scorer_path, Config.alphabet) + scorer = Scorer( + FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet + ) else: scorer = None - test_sets = [create_dataset([csv], - batch_size=FLAGS.test_batch_size, - train_phase=False, - augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], - reverse=FLAGS.reverse_test, - limit=FLAGS.limit_test) for csv in test_csvs] - iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), - tfv1.data.get_output_shapes(test_sets[0]), - output_classes=tfv1.data.get_output_classes(test_sets[0])) + test_sets = [ + create_dataset( + [csv], + batch_size=FLAGS.test_batch_size, + train_phase=False, + augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], + reverse=FLAGS.reverse_test, + limit=FLAGS.limit_test, + ) + for csv in test_csvs + ] + iterator = tfv1.data.Iterator.from_structure( + tfv1.data.get_output_types(test_sets[0]), + tfv1.data.get_output_shapes(test_sets[0]), + output_classes=tfv1.data.get_output_classes(test_sets[0]), + ) test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets] batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next() # One rate per layer no_dropout = [None] * 6 - logits, _ = create_model(batch_x=batch_x, - seq_length=batch_x_len, - dropout=no_dropout) + logits, _ = create_model( + batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout + ) # Transpose to batch major and apply softmax for decoder transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2])) - loss = tfv1.nn.ctc_loss(labels=batch_y, - inputs=logits, - sequence_length=batch_x_len) + loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len) tfv1.train.get_or_create_global_step() @@ -93,9 +102,11 @@ def evaluate(test_csvs, create_model): predictions = [] ground_truths = [] - bar = create_progressbar(prefix='Test epoch | ', - widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start() - log_progress('Test epoch...') + bar = create_progressbar( + prefix="Test epoch | ", + widgets=["Steps: ", progressbar.Counter(), " | ", progressbar.Timer()], + ).start() + log_progress("Test epoch...") step_count = 0 @@ -105,17 +116,35 @@ def evaluate(test_csvs, create_model): # First pass, compute losses and transposed logits for decoding while True: try: - batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \ - session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y]) + ( + batch_wav_filenames, + batch_logits, + batch_loss, + batch_lengths, + batch_transcripts, + ) = session.run( + [batch_wav_filename, transposed, loss, batch_x_len, batch_y] + ) except tf.errors.OutOfRangeError: break - decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width, - num_processes=num_processes, scorer=scorer, - cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n) + decoded = ctc_beam_search_decoder_batch( + batch_logits, + batch_lengths, + Config.alphabet, + FLAGS.beam_width, + num_processes=num_processes, + scorer=scorer, + cutoff_prob=FLAGS.cutoff_prob, + cutoff_top_n=FLAGS.cutoff_top_n, + ) predictions.extend(d[0][1] for d in decoded) - ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)) - wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames) + ground_truths.extend( + sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet) + ) + wav_filenames.extend( + wav_filename.decode("UTF-8") for wav_filename in batch_wav_filenames + ) losses.extend(batch_loss) step_count += 1 @@ -124,12 +153,14 @@ def evaluate(test_csvs, create_model): bar.finish() # Print test summary - test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset) + test_samples = calculate_and_print_report( + wav_filenames, ground_truths, predictions, losses, dataset + ) return test_samples samples = [] for csv, init_op in zip(test_csvs, test_init_ops): - print('Testing model on {}'.format(csv)) + print("Testing model on {}".format(csv)) samples.extend(run_test(init_op, dataset=csv)) return samples @@ -138,12 +169,17 @@ def main(_): initialize_globals() if not FLAGS.test_files: - log_error('You need to specify what files to use for evaluation via ' - 'the --test_files flag.') + log_error( + "You need to specify what files to use for evaluation via " + "the --test_files flag." + ) sys.exit(1) - from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel - samples = evaluate(FLAGS.test_files.split(','), create_model) + from .train import ( # pylint: disable=cyclic-import,import-outside-toplevel + create_model, + ) + + samples = evaluate(FLAGS.test_files.split(","), create_model) if FLAGS.test_output_file: save_samples_json(samples, FLAGS.test_output_file) @@ -153,5 +189,6 @@ def run_script(): create_flags() absl.app.run(main) -if __name__ == '__main__': + +if __name__ == "__main__": run_script() diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 77d32487..15378b6a 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -5,44 +5,70 @@ from __future__ import absolute_import, division, print_function import os import sys -LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0 -DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3' -os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL +LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0 +DESIRED_LOG_LEVEL = ( + sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3" +) +os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL + +import shutil +import time import absl.app import numpy as np import progressbar -import shutil + import tensorflow as tf import tensorflow.compat.v1 as tfv1 -import time -tfv1.logging.set_verbosity({ - '0': tfv1.logging.DEBUG, - '1': tfv1.logging.INFO, - '2': tfv1.logging.WARN, - '3': tfv1.logging.ERROR -}.get(DESIRED_LOG_LEVEL)) +tfv1.logging.set_verbosity( + { + "0": tfv1.logging.DEBUG, + "1": tfv1.logging.INFO, + "2": tfv1.logging.WARN, + "3": tfv1.logging.ERROR, + }.get(DESIRED_LOG_LEVEL) +) from datetime import datetime -from coqui_stt_ctcdecoder import ctc_beam_search_decoder, Scorer + +from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder +from six.moves import range, zip + from .evaluate import evaluate -from six.moves import zip, range from .util.augmentations import NormalizeSampleRate +from .util.checkpoints import ( + load_graph_for_evaluation, + load_or_init_graph_for_training, + reload_best_checkpoint, +) from .util.config import Config, initialize_globals -from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint from .util.evaluate_tools import save_samples_json -from .util.feeding import create_dataset, audio_to_features, audiofile_to_features -from .util.flags import create_flags, FLAGS -from .util.helpers import check_ctcdecoder_version, ExceptionBox -from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn -from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote +from .util.feeding import audio_to_features, audiofile_to_features, create_dataset +from .util.flags import FLAGS, create_flags +from .util.helpers import ExceptionBox, check_ctcdecoder_version +from .util.io import ( + is_remote_path, + isdir_remote, + listdir_remote, + open_remote, + remove_remote, +) +from .util.logging import ( + create_progressbar, + log_debug, + log_error, + log_info, + log_progress, + log_warn, +) check_ctcdecoder_version() # Graph Creation # ============== + def variable_on_cpu(name, shape, initializer): r""" Next we concern ourselves with graph creation. @@ -64,11 +90,15 @@ def create_overlapping_windows(batch_x): # Create a constant convolution filter using an identity matrix, so that the # convolution returns patches of the input tensor as is, and we can create # overlapping windows over the MFCCs. - eye_filter = tf.constant(np.eye(window_width * num_channels) - .reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation + eye_filter = tf.constant( + np.eye(window_width * num_channels).reshape( + window_width, num_channels, window_width * num_channels + ), + tf.float32, + ) # pylint: disable=bad-continuation # Create overlapping windows - batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME') + batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding="SAME") # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input] batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels]) @@ -78,8 +108,14 @@ def create_overlapping_windows(batch_x): def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False): with tfv1.variable_scope(name): - bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer()) - weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")) + bias = variable_on_cpu("bias", [units], tfv1.zeros_initializer()) + weights = variable_on_cpu( + "weights", + [x.shape[-1], units], + tfv1.keras.initializers.VarianceScaling( + scale=1.0, mode="fan_avg", distribution="uniform" + ), + ) output = tf.nn.bias_add(tf.matmul(x, weights), bias) @@ -97,22 +133,28 @@ def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False): def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse): - with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'): - fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim, - forget_bias=0, - reuse=reuse, - name='cudnn_compatible_lstm_cell') + with tfv1.variable_scope("cudnn_lstm/rnn/multi_rnn_cell/cell_0"): + fw_cell = tf.contrib.rnn.LSTMBlockFusedCell( + Config.n_cell_dim, + forget_bias=0, + reuse=reuse, + name="cudnn_compatible_lstm_cell", + ) - output, output_state = fw_cell(inputs=x, - dtype=tf.float32, - sequence_length=seq_length, - initial_state=previous_state) + output, output_state = fw_cell( + inputs=x, + dtype=tf.float32, + sequence_length=seq_length, + initial_state=previous_state, + ) return output, output_state def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _): - assert previous_state is None # 'Passing previous state not supported with CuDNN backend' + assert ( + previous_state is None + ) # 'Passing previous state not supported with CuDNN backend' # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate # the object it creates the variables, and then you just call it several times @@ -123,45 +165,62 @@ def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _): # emulating a static function variable. if not rnn_impl_cudnn_rnn.cell: # Forward direction cell: - fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1, - num_units=Config.n_cell_dim, - input_mode='linear_input', - direction='unidirectional', - dtype=tf.float32) + fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM( + num_layers=1, + num_units=Config.n_cell_dim, + input_mode="linear_input", + direction="unidirectional", + dtype=tf.float32, + ) rnn_impl_cudnn_rnn.cell = fw_cell - output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x, - sequence_lengths=seq_length) + output, output_state = rnn_impl_cudnn_rnn.cell( + inputs=x, sequence_lengths=seq_length + ) return output, output_state + rnn_impl_cudnn_rnn.cell = None def rnn_impl_static_rnn(x, seq_length, previous_state, reuse): - with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'): + with tfv1.variable_scope("cudnn_lstm/rnn/multi_rnn_cell"): # Forward direction cell: - fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim, - forget_bias=0, - reuse=reuse, - name='cudnn_compatible_lstm_cell') + fw_cell = tfv1.nn.rnn_cell.LSTMCell( + Config.n_cell_dim, + forget_bias=0, + reuse=reuse, + name="cudnn_compatible_lstm_cell", + ) # Split rank N tensor into list of rank N-1 tensors x = [x[l] for l in range(x.shape[0])] - output, output_state = tfv1.nn.static_rnn(cell=fw_cell, - inputs=x, - sequence_length=seq_length, - initial_state=previous_state, - dtype=tf.float32, - scope='cell_0') + output, output_state = tfv1.nn.static_rnn( + cell=fw_cell, + inputs=x, + sequence_length=seq_length, + initial_state=previous_state, + dtype=tf.float32, + scope="cell_0", + ) output = tf.concat(output, 0) return output, output_state -def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell): +def create_model( + batch_x, + seq_length, + dropout, + reuse=False, + batch_size=None, + previous_state=None, + overlap=True, + rnn_impl=rnn_impl_lstmblockfusedcell, +): layers = {} # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context] @@ -178,14 +237,34 @@ def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, pre # Permute n_steps and batch_size batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3]) # Reshape to prepare input for first layer - batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context) - layers['input_reshaped'] = batch_x + batch_x = tf.reshape( + batch_x, [-1, Config.n_input + 2 * Config.n_input * Config.n_context] + ) # (n_steps*batch_size, n_input + 2*n_input*n_context) + layers["input_reshaped"] = batch_x # The next three blocks will pass `batch_x` through three hidden layers with # clipped RELU activation and dropout. - layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0], layer_norm=FLAGS.layer_norm) - layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1], layer_norm=FLAGS.layer_norm) - layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2], layer_norm=FLAGS.layer_norm) + layers["layer_1"] = layer_1 = dense( + "layer_1", + batch_x, + Config.n_hidden_1, + dropout_rate=dropout[0], + layer_norm=FLAGS.layer_norm, + ) + layers["layer_2"] = layer_2 = dense( + "layer_2", + layer_1, + Config.n_hidden_2, + dropout_rate=dropout[1], + layer_norm=FLAGS.layer_norm, + ) + layers["layer_3"] = layer_3 = dense( + "layer_3", + layer_2, + Config.n_hidden_3, + dropout_rate=dropout[2], + layer_norm=FLAGS.layer_norm, + ) # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`, # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`. @@ -198,20 +277,30 @@ def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, pre # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim] # to a tensor of shape [n_steps*batch_size, n_cell_dim] output = tf.reshape(output, [-1, Config.n_cell_dim]) - layers['rnn_output'] = output - layers['rnn_output_state'] = output_state + layers["rnn_output"] = output + layers["rnn_output_state"] = output_state # Now we feed `output` to the fifth hidden layer with clipped RELU activation - layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5], layer_norm=FLAGS.layer_norm) + layers["layer_5"] = layer_5 = dense( + "layer_5", + output, + Config.n_hidden_5, + dropout_rate=dropout[5], + layer_norm=FLAGS.layer_norm, + ) # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits. - layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False) + layers["layer_6"] = layer_6 = dense( + "layer_6", layer_5, Config.n_hidden_6, relu=False + ) # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6] # to the slightly more useful shape [n_steps, batch_size, n_hidden_6]. # Note, that this differs from the input in that it is time-major. - layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits') - layers['raw_logits'] = layer_6 + layer_6 = tf.reshape( + layer_6, [-1, batch_size, Config.n_hidden_6], name="raw_logits" + ) + layers["raw_logits"] = layer_6 # Output shape: [n_steps, batch_size, n_hidden_6] return layer_6, layers @@ -227,12 +316,13 @@ def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, pre # Conveniently, this loss function is implemented in TensorFlow. # Thus, we can simply make use of this implementation to define our loss. + def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse): - r''' + r""" This routine beam search decodes a mini-batch and calculates the loss and mean edit distance. Next to total and average loss it returns the mean edit distance, the decoded result and the batch's original Y. - ''' + """ # Obtain the next batch of data batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next() @@ -242,13 +332,19 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse): rnn_impl = rnn_impl_lstmblockfusedcell # Calculate the logits of the batch - logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl) + logits, _ = create_model( + batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl + ) # Compute the CTC loss using TensorFlow's `ctc_loss` - total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len) + total_loss = tfv1.nn.ctc_loss( + labels=batch_y, inputs=logits, sequence_length=batch_seq_len + ) # Check if any files lead to non finite loss - non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss))) + non_finite_files = tf.gather( + batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)) + ) # Calculate the average loss across the batch avg_loss = tf.reduce_mean(input_tensor=total_loss) @@ -267,10 +363,12 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse): # we will use the Adam method for optimization (http://arxiv.org/abs/1412.6980), # because, generally, it requires less fine-tuning. def create_optimizer(learning_rate_var): - optimizer = tfv1.train.AdamOptimizer(learning_rate=learning_rate_var, - beta1=FLAGS.beta1, - beta2=FLAGS.beta2, - epsilon=FLAGS.epsilon) + optimizer = tfv1.train.AdamOptimizer( + learning_rate=learning_rate_var, + beta1=FLAGS.beta1, + beta2=FLAGS.beta2, + epsilon=FLAGS.epsilon, + ) return optimizer @@ -290,12 +388,13 @@ def create_optimizer(learning_rate_var): # on which all operations within the tower execute. # For example, all operations of 'tower 0' could execute on the first GPU `tf.device('/gpu:0')`. + def get_tower_results(iterator, optimizer, dropout_rates): - r''' + r""" With this preliminary step out of the way, we can for each GPU introduce a tower for which's batch we calculate and return the optimization gradients and the average loss across towers. - ''' + """ # To calculate the mean of the losses tower_avg_losses = [] @@ -312,10 +411,12 @@ def get_tower_results(iterator, optimizer, dropout_rates): device = Config.available_devices[i] with tf.device(device): # Create a scope for all operations of tower i - with tf.name_scope('tower_%d' % i): + with tf.name_scope("tower_%d" % i): # Calculate the avg_loss and mean_edit_distance and retrieve the decoded # batch along with the original batch's labels (Y) of this tower - avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0) + avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss( + iterator, dropout_rates, reuse=i > 0 + ) # Allow for variables to be re-used by the next tower tfv1.get_variable_scope().reuse_variables() @@ -332,7 +433,9 @@ def get_tower_results(iterator, optimizer, dropout_rates): tower_non_finite_files.append(non_finite_files) avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0) - tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries']) + tfv1.summary.scalar( + name="step_loss", tensor=avg_loss_across_towers, collections=["step_summaries"] + ) all_non_finite_files = tf.concat(tower_non_finite_files, axis=0) @@ -341,11 +444,11 @@ def get_tower_results(iterator, optimizer, dropout_rates): def average_gradients(tower_gradients): - r''' + r""" A routine for computing each variable's average of the gradients obtained from the GPUs. Note also that this code acts as a synchronization point as it requires all GPUs to be finished with their mini-batch before it can run to completion. - ''' + """ # List of average gradients to return to the caller average_grads = [] @@ -377,22 +480,29 @@ def average_gradients(tower_gradients): return average_grads - # Logging # ======= + def log_variable(variable, gradient=None): - r''' + r""" We introduce a function for logging a tensor variable's current state. It logs scalar values for the mean, standard deviation, minimum and maximum. Furthermore it logs a histogram of its state and (if given) of an optimization gradient. - ''' - name = variable.name.replace(':', '_') + """ + name = variable.name.replace(":", "_") mean = tf.reduce_mean(input_tensor=variable) - tfv1.summary.scalar(name='%s/mean' % name, tensor=mean) - tfv1.summary.scalar(name='%s/sttdev' % name, tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean)))) - tfv1.summary.scalar(name='%s/max' % name, tensor=tf.reduce_max(input_tensor=variable)) - tfv1.summary.scalar(name='%s/min' % name, tensor=tf.reduce_min(input_tensor=variable)) + tfv1.summary.scalar(name="%s/mean" % name, tensor=mean) + tfv1.summary.scalar( + name="%s/sttdev" % name, + tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))), + ) + tfv1.summary.scalar( + name="%s/max" % name, tensor=tf.reduce_max(input_tensor=variable) + ) + tfv1.summary.scalar( + name="%s/min" % name, tensor=tf.reduce_min(input_tensor=variable) + ) tfv1.summary.histogram(name=name, values=variable) if gradient is not None: if isinstance(gradient, tf.IndexedSlices): @@ -400,13 +510,13 @@ def log_variable(variable, gradient=None): else: grad_values = gradient if grad_values is not None: - tfv1.summary.histogram(name='%s/gradients' % name, values=grad_values) + tfv1.summary.histogram(name="%s/gradients" % name, values=grad_values) def log_grads_and_vars(grads_and_vars): - r''' + r""" Let's also introduce a helper function for logging collections of gradient/variable tuples. - ''' + """ for gradient, variable in grads_and_vars: log_variable(variable, gradient=gradient) @@ -415,53 +525,71 @@ def train(): exception_box = ExceptionBox() # Create training and validation datasets - train_set = create_dataset(FLAGS.train_files.split(','), - batch_size=FLAGS.train_batch_size, - epochs=FLAGS.epochs, - augmentations=Config.augmentations, - cache_path=FLAGS.feature_cache, - train_phase=True, - exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, - reverse=FLAGS.reverse_train, - limit=FLAGS.limit_train, - buffering=FLAGS.read_buffer) + train_set = create_dataset( + FLAGS.train_files.split(","), + batch_size=FLAGS.train_batch_size, + epochs=FLAGS.epochs, + augmentations=Config.augmentations, + cache_path=FLAGS.feature_cache, + train_phase=True, + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, + reverse=FLAGS.reverse_train, + limit=FLAGS.limit_train, + buffering=FLAGS.read_buffer, + ) - iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), - tfv1.data.get_output_shapes(train_set), - output_classes=tfv1.data.get_output_classes(train_set)) + iterator = tfv1.data.Iterator.from_structure( + tfv1.data.get_output_types(train_set), + tfv1.data.get_output_shapes(train_set), + output_classes=tfv1.data.get_output_classes(train_set), + ) # Make initialization ops for switching between the two sets train_init_op = iterator.make_initializer(train_set) if FLAGS.dev_files: - dev_sources = FLAGS.dev_files.split(',') - dev_sets = [create_dataset([source], - batch_size=FLAGS.dev_batch_size, - train_phase=False, - augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], - exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, - reverse=FLAGS.reverse_dev, - limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer) for source in dev_sources] + dev_sources = FLAGS.dev_files.split(",") + dev_sets = [ + create_dataset( + [source], + batch_size=FLAGS.dev_batch_size, + train_phase=False, + augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + reverse=FLAGS.reverse_dev, + limit=FLAGS.limit_dev, + buffering=FLAGS.read_buffer, + ) + for source in dev_sources + ] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] if FLAGS.metrics_files: - metrics_sources = FLAGS.metrics_files.split(',') - metrics_sets = [create_dataset([source], - batch_size=FLAGS.dev_batch_size, - train_phase=False, - augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], - exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, - reverse=FLAGS.reverse_dev, - limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer) for source in metrics_sources] - metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] + metrics_sources = FLAGS.metrics_files.split(",") + metrics_sets = [ + create_dataset( + [source], + batch_size=FLAGS.dev_batch_size, + train_phase=False, + augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], + exception_box=exception_box, + process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + reverse=FLAGS.reverse_dev, + limit=FLAGS.limit_dev, + buffering=FLAGS.read_buffer, + ) + for source in metrics_sources + ] + metrics_init_ops = [ + iterator.make_initializer(metrics_set) for metrics_set in metrics_sets + ] # Dropout - dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] + dropout_rates = [ + tfv1.placeholder(tf.float32, name="dropout_{}".format(i)) for i in range(6) + ] dropout_feed_dict = { dropout_rates[0]: FLAGS.dropout_rate, dropout_rates[1]: FLAGS.dropout_rate2, @@ -470,21 +598,27 @@ def train(): dropout_rates[4]: FLAGS.dropout_rate5, dropout_rates[5]: FLAGS.dropout_rate6, } - no_dropout_feed_dict = { - rate: 0. for rate in dropout_rates - } + no_dropout_feed_dict = {rate: 0.0 for rate in dropout_rates} # Building the graph - learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False) - reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction)) + learning_rate_var = tfv1.get_variable( + "learning_rate", initializer=FLAGS.learning_rate, trainable=False + ) + reduce_learning_rate_op = learning_rate_var.assign( + tf.multiply(learning_rate_var, FLAGS.plateau_reduction) + ) optimizer = create_optimizer(learning_rate_var) # Enable mixed precision training if FLAGS.automatic_mixed_precision: - log_info('Enabling automatic mixed precision training.') - optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) + log_info("Enabling automatic mixed precision training.") + optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite( + optimizer + ) - gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) + gradients, loss, non_finite_files = get_tower_results( + iterator, optimizer, dropout_rates + ) # Average tower gradients across GPUs avg_tower_gradients = average_gradients(gradients) @@ -492,38 +626,46 @@ def train(): # global_step is automagically incremented by the optimizer global_step = tfv1.train.get_or_create_global_step() - apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step) + apply_gradient_op = optimizer.apply_gradients( + avg_tower_gradients, global_step=global_step + ) # Summaries - step_summaries_op = tfv1.summary.merge_all('step_summaries') + step_summaries_op = tfv1.summary.merge_all("step_summaries") step_summary_writers = { - 'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120), - 'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120), - 'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120), + "train": tfv1.summary.FileWriter( + os.path.join(FLAGS.summary_dir, "train"), max_queue=120 + ), + "dev": tfv1.summary.FileWriter( + os.path.join(FLAGS.summary_dir, "dev"), max_queue=120 + ), + "metrics": tfv1.summary.FileWriter( + os.path.join(FLAGS.summary_dir, "metrics"), max_queue=120 + ), } human_readable_set_names = { - 'train': 'Training', - 'dev': 'Validation', - 'metrics': 'Metrics', + "train": "Training", + "dev": "Validation", + "metrics": "Metrics", } # Checkpointing checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep) - checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train') + checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, "train") best_dev_saver = tfv1.train.Saver(max_to_keep=1) - best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') + best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, "best_dev") # Save flags next to checkpoints if not is_remote_path(FLAGS.save_checkpoint_dir): os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) - flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') - with open_remote(flags_file, 'w') as fout: + flags_file = os.path.join(FLAGS.save_checkpoint_dir, "flags.txt") + with open_remote(flags_file, "w") as fout: fout.write(FLAGS.flags_into_string()) with tfv1.Session(config=Config.session_config) as session: - log_debug('Session opened.') + log_debug("Session opened.") # Prevent further graph changes tfv1.get_default_graph().finalize() @@ -532,7 +674,7 @@ def train(): load_or_init_graph_for_training(session) def run_set(set_name, epoch, init_op, dataset=None): - is_train = set_name == 'train' + is_train = set_name == "train" train_op = apply_gradient_op if is_train else [] feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict @@ -543,26 +685,43 @@ def train(): checkpoint_time = time.time() if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache: - feature_cache_index = FLAGS.feature_cache + '.index' - if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index): - log_info('Invalidating feature cache') - remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files + feature_cache_index = FLAGS.feature_cache + ".index" + if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile( + feature_cache_index + ): + log_info("Invalidating feature cache") + remove_remote( + feature_cache_index + ) # this will let TF also overwrite the related cache data files # Setup progress bar class LossWidget(progressbar.widgets.FormatLabel): def __init__(self): - progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f') + progressbar.widgets.FormatLabel.__init__( + self, format="Loss: %(mean_loss)f" + ) def __call__(self, progress, data, **kwargs): - data['mean_loss'] = total_loss / step_count if step_count else 0.0 - return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs) + data["mean_loss"] = total_loss / step_count if step_count else 0.0 + return progressbar.widgets.FormatLabel.__call__( + self, progress, data, **kwargs + ) - prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name]) - widgets = [' | ', progressbar.widgets.Timer(), - ' | Steps: ', progressbar.widgets.Counter(), - ' | ', LossWidget()] - suffix = ' | Dataset: {}'.format(dataset) if dataset else None - pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start() + prefix = "Epoch {} | {:>10}".format( + epoch, human_readable_set_names[set_name] + ) + widgets = [ + " | ", + progressbar.widgets.Timer(), + " | Steps: ", + progressbar.widgets.Counter(), + " | ", + LossWidget(), + ] + suffix = " | Dataset: {}".format(dataset) if dataset else None + pbar = create_progressbar( + prefix=prefix, widgets=widgets, suffix=suffix + ).start() # Initialize iterator to the appropriate dataset session.run(init_op) @@ -570,18 +729,33 @@ def train(): # Batch loop while True: try: - _, current_step, batch_loss, problem_files, step_summary = \ - session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], - feed_dict=feed_dict) + ( + _, + current_step, + batch_loss, + problem_files, + step_summary, + ) = session.run( + [ + train_op, + global_step, + loss, + non_finite_files, + step_summaries_op, + ], + feed_dict=feed_dict, + ) exception_box.raise_if_set() except tf.errors.OutOfRangeError: exception_box.raise_if_set() break if problem_files.size > 0: - problem_files = [f.decode('utf8') for f in problem_files[..., 0]] - log_error('The following files caused an infinite (or NaN) ' - 'loss: {}'.format(','.join(problem_files))) + problem_files = [f.decode("utf8") for f in problem_files[..., 0]] + log_error( + "The following files caused an infinite (or NaN) " + "loss: {}".format(",".join(problem_files)) + ) total_loss += batch_loss step_count += 1 @@ -590,25 +764,33 @@ def train(): step_summary_writer.add_summary(step_summary, current_step) - if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: - checkpoint_saver.save(session, checkpoint_path, global_step=current_step) + if ( + is_train + and FLAGS.checkpoint_secs > 0 + and time.time() - checkpoint_time > FLAGS.checkpoint_secs + ): + checkpoint_saver.save( + session, checkpoint_path, global_step=current_step + ) checkpoint_time = time.time() pbar.finish() mean_loss = total_loss / step_count if step_count > 0 else 0.0 return mean_loss, step_count - log_info('STARTING Optimization') + log_info("STARTING Optimization") train_start_time = datetime.utcnow() - best_dev_loss = float('inf') + best_dev_loss = float("inf") dev_losses = [] epochs_without_improvement = 0 try: for epoch in range(FLAGS.epochs): # Training - log_progress('Training epoch %d...' % epoch) - train_loss, _ = run_set('train', epoch, train_init_op) - log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss)) + log_progress("Training epoch %d..." % epoch) + train_loss, _ = run_set("train", epoch, train_init_op) + log_progress( + "Finished training epoch %d - loss: %f" % (epoch, train_loss) + ) checkpoint_saver.save(session, checkpoint_path, global_step=global_step) if FLAGS.dev_files: @@ -616,11 +798,14 @@ def train(): dev_loss = 0.0 total_steps = 0 for source, init_op in zip(dev_sources, dev_init_ops): - log_progress('Validating epoch %d on %s...' % (epoch, source)) - set_loss, steps = run_set('dev', epoch, init_op, dataset=source) + log_progress("Validating epoch %d on %s..." % (epoch, source)) + set_loss, steps = run_set("dev", epoch, init_op, dataset=source) dev_loss += set_loss * steps total_steps += steps - log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss)) + log_progress( + "Finished validating epoch %d on %s - loss: %f" + % (epoch, source, set_loss) + ) dev_loss = dev_loss / total_steps dev_losses.append(dev_loss) @@ -635,13 +820,27 @@ def train(): # Save new best model if dev_loss < best_dev_loss: best_dev_loss = dev_loss - save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') - log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) + save_path = best_dev_saver.save( + session, + best_dev_path, + global_step=global_step, + latest_filename="best_dev_checkpoint", + ) + log_info( + "Saved new best validating model with loss %f to: %s" + % (best_dev_loss, save_path) + ) # Early stopping - if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs: - log_info('Early stop triggered as the loss did not improve the last {} epochs'.format( - epochs_without_improvement)) + if ( + FLAGS.early_stop + and epochs_without_improvement == FLAGS.es_epochs + ): + log_info( + "Early stop triggered as the loss did not improve the last {} epochs".format( + epochs_without_improvement + ) + ) break # Reduce learning rate on plateau @@ -658,31 +857,46 @@ def train(): # Reduce learning rate session.run(reduce_learning_rate_op) current_learning_rate = learning_rate_var.eval() - log_info('Encountered a plateau, reducing learning rate to {}'.format( - current_learning_rate)) + log_info( + "Encountered a plateau, reducing learning rate to {}".format( + current_learning_rate + ) + ) # Overwrite best checkpoint with new learning rate value - save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') - log_info("Saved best validating model with reduced learning rate to: %s" % (save_path)) + save_path = best_dev_saver.save( + session, + best_dev_path, + global_step=global_step, + latest_filename="best_dev_checkpoint", + ) + log_info( + "Saved best validating model with reduced learning rate to: %s" + % (save_path) + ) if FLAGS.metrics_files: # Read only metrics, not affecting best validation loss tracking for source, init_op in zip(metrics_sources, metrics_init_ops): - log_progress('Metrics for epoch %d on %s...' % (epoch, source)) - set_loss, _ = run_set('metrics', epoch, init_op, dataset=source) - log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss)) - - print('-' * 80) + log_progress("Metrics for epoch %d on %s..." % (epoch, source)) + set_loss, _ = run_set("metrics", epoch, init_op, dataset=source) + log_progress( + "Metrics for epoch %d on %s - loss: %f" + % (epoch, source, set_loss) + ) + print("-" * 80) except KeyboardInterrupt: pass - log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) - log_debug('Session closed.') + log_info( + "FINISHED optimization in {}".format(datetime.utcnow() - train_start_time) + ) + log_debug("Session closed.") def test(): - samples = evaluate(FLAGS.test_files.split(','), create_model) + samples = evaluate(FLAGS.test_files.split(","), create_model) if FLAGS.test_output_file: save_samples_json(samples, FLAGS.test_output_file) @@ -691,26 +905,43 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): batch_size = batch_size if batch_size > 0 else None # Create feature computation graph - input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples') + input_samples = tfv1.placeholder( + tf.float32, [Config.audio_window_samples], "input_samples" + ) samples = tf.expand_dims(input_samples, -1) mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate) - mfccs = tf.identity(mfccs, name='mfccs') + mfccs = tf.identity(mfccs, name="mfccs") # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] # This shape is read by the native_client in STT_CreateModel to know the # value of n_steps, n_context and n_input. Make sure you update the code # there if this shape is changed. - input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node') - seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths') + input_tensor = tfv1.placeholder( + tf.float32, + [ + batch_size, + n_steps if n_steps > 0 else None, + 2 * Config.n_context + 1, + Config.n_input, + ], + name="input_node", + ) + seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths") if batch_size <= 0: # no state management since n_step is expected to be dynamic too (see below) previous_state = None else: - previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c') - previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h') + previous_state_c = tfv1.placeholder( + tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c" + ) + previous_state_h = tfv1.placeholder( + tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h" + ) - previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h) + previous_state = tf.nn.rnn_cell.LSTMStateTuple( + previous_state_c, previous_state_h + ) # One rate per layer no_dropout = [None] * 6 @@ -720,13 +951,15 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): else: rnn_impl = rnn_impl_lstmblockfusedcell - logits, layers = create_model(batch_x=input_tensor, - batch_size=batch_size, - seq_length=seq_length if not FLAGS.export_tflite else None, - dropout=no_dropout, - previous_state=previous_state, - overlap=False, - rnn_impl=rnn_impl) + logits, layers = create_model( + batch_x=input_tensor, + batch_size=batch_size, + seq_length=seq_length if not FLAGS.export_tflite else None, + dropout=no_dropout, + previous_state=previous_state, + overlap=False, + rnn_impl=rnn_impl, + ) # TF Lite runtime will check that input dimensions are 1, 2 or 4 # by default we get 3, the middle one being batch_size which is forced to @@ -735,47 +968,50 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): logits = tf.squeeze(logits, [1]) # Apply softmax for CTC decoder - probs = tf.nn.softmax(logits, name='logits') + probs = tf.nn.softmax(logits, name="logits") if batch_size <= 0: if tflite: - raise NotImplementedError('dynamic batch_size does not support tflite nor streaming') + raise NotImplementedError( + "dynamic batch_size does not support tflite nor streaming" + ) if n_steps > 0: - raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too') + raise NotImplementedError( + "dynamic batch_size expect n_steps to be dynamic too" + ) return ( { - 'input': input_tensor, - 'input_lengths': seq_length, + "input": input_tensor, + "input_lengths": seq_length, }, { - 'outputs': probs, + "outputs": probs, }, - layers + layers, ) - new_state_c, new_state_h = layers['rnn_output_state'] - new_state_c = tf.identity(new_state_c, name='new_state_c') - new_state_h = tf.identity(new_state_h, name='new_state_h') + new_state_c, new_state_h = layers["rnn_output_state"] + new_state_c = tf.identity(new_state_c, name="new_state_c") + new_state_h = tf.identity(new_state_h, name="new_state_h") inputs = { - 'input': input_tensor, - 'previous_state_c': previous_state_c, - 'previous_state_h': previous_state_h, - 'input_samples': input_samples, + "input": input_tensor, + "previous_state_c": previous_state_c, + "previous_state_h": previous_state_h, + "input_samples": input_samples, } if not FLAGS.export_tflite: - inputs['input_lengths'] = seq_length + inputs["input_lengths"] = seq_length outputs = { - 'outputs': probs, - 'new_state_c': new_state_c, - 'new_state_h': new_state_h, - 'mfccs': mfccs, - + "outputs": probs, + "new_state_c": new_state_c, + "new_state_h": new_state_h, + "mfccs": mfccs, # Expose internal layers for downstream applications - 'layer_3': layers['layer_3'], - 'layer_5': layers['layer_5'] + "layer_3": layers["layer_3"], + "layer_5": layers["layer_5"], } return inputs, outputs, layers @@ -786,41 +1022,61 @@ def file_relative_read(fname): def export(): - r''' + r""" Restores the trained variables into a simpler graph that will be exported for serving. - ''' - log_info('Exporting the model...') + """ + log_info("Exporting the model...") - inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite) + inputs, outputs, _ = create_inference_graph( + batch_size=FLAGS.export_batch_size, + n_steps=FLAGS.n_steps, + tflite=FLAGS.export_tflite, + ) - graph_version = int(file_relative_read('GRAPH_VERSION').strip()) + graph_version = int(file_relative_read("GRAPH_VERSION").strip()) assert graph_version > 0 - outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version') - outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') - outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') - outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') - outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width') - outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet') + outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version") + outputs["metadata_sample_rate"] = tf.constant( + [FLAGS.audio_sample_rate], name="metadata_sample_rate" + ) + outputs["metadata_feature_win_len"] = tf.constant( + [FLAGS.feature_win_len], name="metadata_feature_win_len" + ) + outputs["metadata_feature_win_step"] = tf.constant( + [FLAGS.feature_win_step], name="metadata_feature_win_step" + ) + outputs["metadata_beam_width"] = tf.constant( + [FLAGS.export_beam_width], name="metadata_beam_width" + ) + outputs["metadata_alphabet"] = tf.constant( + [Config.alphabet.Serialize()], name="metadata_alphabet" + ) if FLAGS.export_language: - outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language') + outputs["metadata_language"] = tf.constant( + [FLAGS.export_language.encode("utf-8")], name="metadata_language" + ) # Prevent further graph changes tfv1.get_default_graph().finalize() - output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)] - output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)] + output_names_tensors = [ + tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor) + ] + output_names_ops = [ + op.name for op in outputs.values() if isinstance(op, tf.Operation) + ] output_names = output_names_tensors + output_names_ops with tf.Session() as session: # Restore variables from checkpoint load_graph_for_evaluation(session) - output_filename = FLAGS.export_file_name + '.pb' + output_filename = FLAGS.export_file_name + ".pb" if FLAGS.remove_export: if isdir_remote(FLAGS.export_dir): - log_info('Removing old export') + log_info("Removing old export") remove_remote(FLAGS.export_dir) output_graph_path = os.path.join(FLAGS.export_dir, output_filename) @@ -831,67 +1087,86 @@ def export(): frozen_graph = tfv1.graph_util.convert_variables_to_constants( sess=session, input_graph_def=tfv1.get_default_graph().as_graph_def(), - output_node_names=output_names) + output_node_names=output_names, + ) frozen_graph = tfv1.graph_util.extract_sub_graph( - graph_def=frozen_graph, - dest_nodes=output_names) + graph_def=frozen_graph, dest_nodes=output_names + ) if not FLAGS.export_tflite: - with open_remote(output_graph_path, 'wb') as fout: + with open_remote(output_graph_path, "wb") as fout: fout.write(frozen_graph.SerializeToString()) else: - output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) + output_tflite_path = os.path.join( + FLAGS.export_dir, output_filename.replace(".pb", ".tflite") + ) - converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) + converter = tf.lite.TFLiteConverter( + frozen_graph, + input_tensors=inputs.values(), + output_tensors=outputs.values(), + ) converter.optimizations = [tf.lite.Optimize.DEFAULT] # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite converter.allow_custom_ops = True tflite_model = converter.convert() - with open_remote(output_tflite_path, 'wb') as fout: + with open_remote(output_tflite_path, "wb") as fout: fout.write(tflite_model) - log_info('Models exported at %s' % (FLAGS.export_dir)) + log_info("Models exported at %s" % (FLAGS.export_dir)) - metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format( - FLAGS.export_author_id, - FLAGS.export_model_name, - FLAGS.export_model_version)) + metadata_fname = os.path.join( + FLAGS.export_dir, + "{}_{}_{}.md".format( + FLAGS.export_author_id, FLAGS.export_model_name, FLAGS.export_model_version + ), + ) - model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow' - with open_remote(metadata_fname, 'w') as f: - f.write('---\n') - f.write('author: {}\n'.format(FLAGS.export_author_id)) - f.write('model_name: {}\n'.format(FLAGS.export_model_name)) - f.write('model_version: {}\n'.format(FLAGS.export_model_version)) - f.write('contact_info: {}\n'.format(FLAGS.export_contact_info)) - f.write('license: {}\n'.format(FLAGS.export_license)) - f.write('language: {}\n'.format(FLAGS.export_language)) - f.write('runtime: {}\n'.format(model_runtime)) - f.write('min_stt_version: {}\n'.format(FLAGS.export_min_stt_version)) - f.write('max_stt_version: {}\n'.format(FLAGS.export_max_stt_version)) - f.write('acoustic_model_url: \n') - f.write('scorer_url: \n') - f.write('---\n') - f.write('{}\n'.format(FLAGS.export_description)) + model_runtime = "tflite" if FLAGS.export_tflite else "tensorflow" + with open_remote(metadata_fname, "w") as f: + f.write("---\n") + f.write("author: {}\n".format(FLAGS.export_author_id)) + f.write("model_name: {}\n".format(FLAGS.export_model_name)) + f.write("model_version: {}\n".format(FLAGS.export_model_version)) + f.write("contact_info: {}\n".format(FLAGS.export_contact_info)) + f.write("license: {}\n".format(FLAGS.export_license)) + f.write("language: {}\n".format(FLAGS.export_language)) + f.write("runtime: {}\n".format(model_runtime)) + f.write("min_stt_version: {}\n".format(FLAGS.export_min_stt_version)) + f.write("max_stt_version: {}\n".format(FLAGS.export_max_stt_version)) + f.write( + "acoustic_model_url: \n" + ) + f.write( + "scorer_url: \n" + ) + f.write("---\n") + f.write("{}\n".format(FLAGS.export_description)) - log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname)) + log_info( + "Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.".format( + metadata_fname + ) + ) def package_zip(): # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip - export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/' + export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), "") # Force ending '/' if is_remote_path(export_dir): - log_error("Cannot package remote path zip %s. Please do this manually." % export_dir) + log_error( + "Cannot package remote path zip %s. Please do this manually." % export_dir + ) return zip_filename = os.path.dirname(export_dir) - + shutil.copy(FLAGS.scorer_path, export_dir) - archive = shutil.make_archive(zip_filename, 'zip', export_dir) - log_info('Exported packaged model {}'.format(archive)) + archive = shutil.make_archive(zip_filename, "zip", export_dir) + log_info("Exported packaged model {}".format(archive)) def do_single_file_inference(input_file_path): @@ -913,23 +1188,32 @@ def do_single_file_inference(input_file_path): features = create_overlapping_windows(features).eval(session=session) features_len = features_len.eval(session=session) - probs = outputs['outputs'].eval(feed_dict={ - inputs['input']: features, - inputs['input_lengths']: features_len, - inputs['previous_state_c']: previous_state_c, - inputs['previous_state_h']: previous_state_h, - }, session=session) + probs = outputs["outputs"].eval( + feed_dict={ + inputs["input"]: features, + inputs["input_lengths"]: features_len, + inputs["previous_state_c"]: previous_state_c, + inputs["previous_state_h"]: previous_state_h, + }, + session=session, + ) probs = np.squeeze(probs) if FLAGS.scorer_path: - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, - FLAGS.scorer_path, Config.alphabet) + scorer = Scorer( + FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet + ) else: scorer = None - decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width, - scorer=scorer, cutoff_prob=FLAGS.cutoff_prob, - cutoff_top_n=FLAGS.cutoff_top_n) + decoded = ctc_beam_search_decoder( + probs, + Config.alphabet, + FLAGS.beam_width, + scorer=scorer, + cutoff_prob=FLAGS.cutoff_prob, + cutoff_top_n=FLAGS.cutoff_top_n, + ) # Print highest probability result print(decoded[0][1]) @@ -937,18 +1221,25 @@ def do_single_file_inference(input_file_path): def early_training_checks(): # Check for proper scorer early if FLAGS.scorer_path: - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, - FLAGS.scorer_path, Config.alphabet) + scorer = Scorer( + FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet + ) del scorer - if FLAGS.train_files and FLAGS.test_files and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir: - log_warn('WARNING: You specified different values for --load_checkpoint_dir ' - 'and --save_checkpoint_dir, but you are running training and testing ' - 'in a single invocation. The testing step will respect --load_checkpoint_dir, ' - 'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. ' - 'Train and test in two separate invocations, specifying the correct ' - '--load_checkpoint_dir in both cases, or use the same location ' - 'for loading and saving.') + if ( + FLAGS.train_files + and FLAGS.test_files + and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir + ): + log_warn( + "WARNING: You specified different values for --load_checkpoint_dir " + "and --save_checkpoint_dir, but you are running training and testing " + "in a single invocation. The testing step will respect --load_checkpoint_dir, " + "and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. " + "Train and test in two separate invocations, specifying the correct " + "--load_checkpoint_dir in both cases, or use the same location " + "for loading and saving." + ) def main(_): @@ -973,7 +1264,9 @@ def main(_): FLAGS.export_tflite = True if listdir_remote(FLAGS.export_dir): - log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) + log_error( + "Directory {} is not empty, please fix this.".format(FLAGS.export_dir) + ) sys.exit(1) export() @@ -988,5 +1281,6 @@ def run_script(): create_flags() absl.app.run(main) -if __name__ == '__main__': + +if __name__ == "__main__": run_script() diff --git a/training/coqui_stt_training/util/audio.py b/training/coqui_stt_training/util/audio.py index 444fbe9d..ab96c606 100644 --- a/training/coqui_stt_training/util/audio.py +++ b/training/coqui_stt_training/util/audio.py @@ -2,28 +2,29 @@ import collections import ctypes import io import math -import numpy as np import os -import pyogg import tempfile import wave +from collections import namedtuple + +import numpy as np +import pyogg from .helpers import LimitingPool -from collections import namedtuple -from .io import open_remote, remove_remote, copy_remote, is_remote_path +from .io import copy_remote, is_remote_path, open_remote, remove_remote -AudioFormat = namedtuple('AudioFormat', 'rate channels width') +AudioFormat = namedtuple("AudioFormat", "rate channels width") DEFAULT_RATE = 16000 DEFAULT_CHANNELS = 1 DEFAULT_WIDTH = 2 DEFAULT_FORMAT = AudioFormat(DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH) -AUDIO_TYPE_NP = 'application/vnd.mozilla.np' -AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm' -AUDIO_TYPE_WAV = 'audio/wav' -AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus' -AUDIO_TYPE_OGG_OPUS = 'application/vnd.deepspeech.ogg_opus' +AUDIO_TYPE_NP = "application/vnd.mozilla.np" +AUDIO_TYPE_PCM = "application/vnd.mozilla.pcm" +AUDIO_TYPE_WAV = "audio/wav" +AUDIO_TYPE_OPUS = "application/vnd.mozilla.opus" +AUDIO_TYPE_OGG_OPUS = "application/vnd.deepspeech.ogg_opus" SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, AUDIO_TYPE_OGG_OPUS] @@ -49,6 +50,7 @@ class Sample: duration : float Audio duration of the sample in seconds """ + def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None): """ Parameters @@ -74,20 +76,26 @@ class Sample: self.audio_format = audio_format self.sample_id = sample_id if audio_type in SERIALIZABLE_AUDIO_TYPES: - self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) + self.audio = ( + raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) + ) self.duration = read_duration(audio_type, self.audio) if not self.audio_format: self.audio_format = read_format(audio_type, self.audio) else: self.audio = raw_data if self.audio_format is None: - raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type)) + raise ValueError( + 'For audio type "{}" parameter "audio_format" is mandatory'.format( + self.audio_type + ) + ) if audio_type == AUDIO_TYPE_PCM: self.duration = get_pcm_duration(len(self.audio), self.audio_format) elif audio_type == AUDIO_TYPE_NP: self.duration = get_np_duration(len(self.audio), self.audio_format) else: - raise ValueError('Unsupported audio type: {}'.format(self.audio_type)) + raise ValueError("Unsupported audio type: {}".format(self.audio_type)) def change_audio_type(self, new_audio_type, bitrate=None): """ @@ -102,7 +110,10 @@ class Sample: """ if self.audio_type == new_audio_type: return - if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in SERIALIZABLE_AUDIO_TYPES: + if ( + new_audio_type == AUDIO_TYPE_PCM + and self.audio_type in SERIALIZABLE_AUDIO_TYPES + ): self.audio_format, audio = read_audio(self.audio_type, self.audio) self.audio.close() self.audio = audio @@ -114,18 +125,27 @@ class Sample: elif new_audio_type in SERIALIZABLE_AUDIO_TYPES: self.change_audio_type(AUDIO_TYPE_PCM) audio_bytes = io.BytesIO() - write_audio(new_audio_type, audio_bytes, self.audio, audio_format=self.audio_format, bitrate=bitrate) + write_audio( + new_audio_type, + audio_bytes, + self.audio, + audio_format=self.audio_format, + bitrate=bitrate, + ) audio_bytes.seek(0) self.audio = audio_bytes else: - raise RuntimeError('Changing audio representation type from "{}" to "{}" not supported' - .format(self.audio_type, new_audio_type)) + raise RuntimeError( + 'Changing audio representation type from "{}" to "{}" not supported'.format( + self.audio_type, new_audio_type + ) + ) self.audio_type = new_audio_type def _unpack_and_change_audio_type(sample_and_audio_type): packed_sample, audio_type, bitrate = sample_and_audio_type - if hasattr(packed_sample, 'unpack'): + if hasattr(packed_sample, "unpack"): sample = packed_sample.unpack() else: sample = packed_sample @@ -133,20 +153,31 @@ def _unpack_and_change_audio_type(sample_and_audio_type): return sample -def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None): +def change_audio_types( + packed_samples, + audio_type=AUDIO_TYPE_PCM, + bitrate=None, + processes=None, + process_ahead=None, +): with LimitingPool(processes=processes, process_ahead=process_ahead) as pool: - yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples)) + yield from pool.imap( + _unpack_and_change_audio_type, + map(lambda s: (s, audio_type, bitrate), packed_samples), + ) def get_loadable_audio_type_from_extension(ext): return { - '.wav': AUDIO_TYPE_WAV, - '.opus': AUDIO_TYPE_OGG_OPUS, + ".wav": AUDIO_TYPE_WAV, + ".opus": AUDIO_TYPE_OGG_OPUS, }.get(ext, None) def read_audio_format_from_wav_file(wav_file): - return AudioFormat(wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()) + return AudioFormat( + wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth() + ) def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT): @@ -163,13 +194,18 @@ def get_np_duration(np_len, audio_format=DEFAULT_FORMAT): return np_len / audio_format.rate -def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT): +def convert_audio( + src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT +): import sox + transformer = sox.Transformer() - transformer.set_output_format(file_type=file_type, - rate=audio_format.rate, - channels=audio_format.channels, - bits=audio_format.width * 8) + transformer.set_output_format( + file_type=file_type, + rate=audio_format.rate, + channels=audio_format.channels, + bits=audio_format.width * 8, + ) transformer.build(src_audio_path, dst_audio_path) @@ -178,6 +214,7 @@ class AudioFile: Audio data file wrapper that ensures that the file is loaded with the correct sample rate, channels, and width, and converts the file on the fly otherwise. """ + def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT): self.audio_path = audio_path self.audio_format = audio_format @@ -188,8 +225,8 @@ class AudioFile: self.tmp_src_file_path = None def __enter__(self): - if self.audio_path.endswith('.wav'): - self.open_file = open_remote(self.audio_path, 'rb') + if self.audio_path.endswith(".wav"): + self.open_file = open_remote(self.audio_path, "rb") self.open_wav = wave.open(self.open_file) if read_audio_format_from_wav_file(self.open_wav) == self.audio_format: if self.as_path: @@ -202,15 +239,20 @@ class AudioFile: # If the format isn't right, copy the file to local tmp dir and do the conversion on disk if is_remote_path(self.audio_path): - _, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav') + _, self.tmp_src_file_path = tempfile.mkstemp(suffix=".wav") copy_remote(self.audio_path, self.tmp_src_file_path, True) self.audio_path = self.tmp_src_file_path - _, self.tmp_file_path = tempfile.mkstemp(suffix='.wav') - convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format) + _, self.tmp_file_path = tempfile.mkstemp(suffix=".wav") + convert_audio( + self.audio_path, + self.tmp_file_path, + file_type="wav", + audio_format=self.audio_format, + ) if self.as_path: return self.tmp_file_path - self.open_wav = wave.open(self.tmp_file_path, 'rb') + self.open_wav = wave.open(self.tmp_file_path, "rb") return self.open_wav def __exit__(self, *args): @@ -230,33 +272,49 @@ def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False): while True: try: data = wav_file.readframes(frame_size) - if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms: + if ( + not yield_remainder + and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms + ): break yield data except EOFError: break -def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False): +def read_frames_from_file( + audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False +): with AudioFile(audio_path, audio_format=audio_format) as wav_file: - for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder): + for frame in read_frames( + wav_file, + frame_duration_ms=frame_duration_ms, + yield_remainder=yield_remainder, + ): yield frame -def vad_split(audio_frames, - audio_format=DEFAULT_FORMAT, - num_padding_frames=10, - threshold=0.5, - aggressiveness=3): +def vad_split( + audio_frames, + audio_format=DEFAULT_FORMAT, + num_padding_frames=10, + threshold=0.5, + aggressiveness=3, +): from webrtcvad import Vad # pylint: disable=import-outside-toplevel + if audio_format.channels != 1: - raise ValueError('VAD-splitting requires mono samples') + raise ValueError("VAD-splitting requires mono samples") if audio_format.width != 2: - raise ValueError('VAD-splitting requires 16 bit samples') + raise ValueError("VAD-splitting requires 16 bit samples") if audio_format.rate not in [8000, 16000, 32000, 48000]: - raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000') + raise ValueError( + "VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000" + ) if aggressiveness not in [0, 1, 2, 3]: - raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3') + raise ValueError( + "VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3" + ) ring_buffer = collections.deque(maxlen=num_padding_frames) triggered = False vad = Vad(int(aggressiveness)) @@ -266,7 +324,9 @@ def vad_split(audio_frames, for frame_index, frame in enumerate(audio_frames): frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000 if int(frame_duration_ms) not in [10, 20, 30]: - raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms') + raise ValueError( + "VAD-splitting only supported for frame durations 10, 20, or 30 ms" + ) is_speech = vad.is_speech(frame, audio_format.rate) if not triggered: ring_buffer.append((frame, is_speech)) @@ -282,23 +342,23 @@ def vad_split(audio_frames, num_unvoiced = len([f for f, speech in ring_buffer if not speech]) if num_unvoiced > threshold * ring_buffer.maxlen: triggered = False - yield b''.join(voiced_frames), \ - frame_duration_ms * max(0, frame_index - len(voiced_frames)), \ - frame_duration_ms * frame_index + yield b"".join(voiced_frames), frame_duration_ms * max( + 0, frame_index - len(voiced_frames) + ), frame_duration_ms * frame_index ring_buffer.clear() voiced_frames = [] if len(voiced_frames) > 0: - yield b''.join(voiced_frames), \ - frame_duration_ms * (frame_index - len(voiced_frames)), \ - frame_duration_ms * (frame_index + 1) + yield b"".join(voiced_frames), frame_duration_ms * ( + frame_index - len(voiced_frames) + ), frame_duration_ms * (frame_index + 1) def pack_number(n, num_bytes): - return n.to_bytes(num_bytes, 'big', signed=False) + return n.to_bytes(num_bytes, "big", signed=False) def unpack_number(data): - return int.from_bytes(data, 'big', signed=False) + return int.from_bytes(data, "big", signed=False) def get_opus_frame_size(rate): @@ -308,7 +368,8 @@ def get_opus_frame_size(rate): def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT, bitrate=None): frame_size = get_opus_frame_size(audio_format.rate) import opuslib # pylint: disable=import-outside-toplevel - encoder = opuslib.Encoder(audio_format.rate, audio_format.channels, 'audio') + + encoder = opuslib.Encoder(audio_format.rate, audio_format.channels, "audio") if bitrate is not None: encoder.bitrate = bitrate chunk_size = frame_size * audio_format.channels * audio_format.width @@ -317,10 +378,10 @@ def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT, bitrate=None) opus_file.write(pack_number(audio_format.channels, OPUS_CHANNELS_SIZE)) opus_file.write(pack_number(audio_format.width, OPUS_WIDTH_SIZE)) for i in range(0, len(audio_data), chunk_size): - chunk = audio_data[i:i + chunk_size] + chunk = audio_data[i : i + chunk_size] # Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer if len(chunk) < chunk_size: - chunk = chunk + b'\0' * (chunk_size - len(chunk)) + chunk = chunk + b"\0" * (chunk_size - len(chunk)) encoded = encoder.encode(chunk, frame_size) opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE)) opus_file.write(encoded) @@ -339,6 +400,7 @@ def read_opus(opus_file): pcm_buffer_size, audio_format = read_opus_header(opus_file) frame_size = get_opus_frame_size(audio_format.rate) import opuslib # pylint: disable=import-outside-toplevel + decoder = opuslib.Decoder(audio_format.rate, audio_format.channels) audio_data = bytearray() while len(audio_data) < pcm_buffer_size: @@ -357,34 +419,30 @@ def read_ogg_opus(ogg_file): opusfile = pyogg.opus.op_open_memory( ubyte_array.from_buffer(ogg_file_buffer), len(ogg_file_buffer), - ctypes.pointer(error) + ctypes.pointer(error), ) if error.value != 0: raise ValueError( - ("Ogg/Opus buffer could not be read." - "Error code: {}").format(error.value) + ("Ogg/Opus buffer could not be read." "Error code: {}").format(error.value) ) channel_count = pyogg.opus.op_channel_count(opusfile, -1) - sample_rate = 48000 # opus files are always 48kHz - sample_width = 2 # always 16-bit + sample_rate = 48000 # opus files are always 48kHz + sample_width = 2 # always 16-bit audio_format = AudioFormat(sample_rate, channel_count, sample_width) # Allocate sufficient memory to store the entire PCM pcm_size = pyogg.opus.op_pcm_total(opusfile, -1) - Buf = pyogg.opus.opus_int16*(pcm_size*channel_count) + Buf = pyogg.opus.opus_int16 * (pcm_size * channel_count) buf = Buf() # Create a pointer to the newly allocated memory. It # seems we can only do pointer arithmetic on void # pointers. See # https://mattgwwalker.wordpress.com/2020/05/30/pointer-manipulation-in-python/ - buf_ptr = ctypes.cast( - ctypes.pointer(buf), - ctypes.c_void_p - ) - assert buf_ptr.value is not None # for mypy + buf_ptr = ctypes.cast(ctypes.pointer(buf), ctypes.c_void_p) + assert buf_ptr.value is not None # for mypy buf_ptr_zero = buf_ptr.value #: Bytes per sample @@ -396,38 +454,24 @@ def read_ogg_opus(ogg_file): while True: # Calculate remaining buffer size remaining_buffer = ( - len(buf) # int - - (buf_ptr.value - buf_ptr_zero) // bytes_per_sample + len(buf) - (buf_ptr.value - buf_ptr_zero) // bytes_per_sample # int ) # Convert buffer pointer to the desired type - ptr = ctypes.cast( - buf_ptr, - ctypes.POINTER(pyogg.opus.opus_int16) - ) + ptr = ctypes.cast(buf_ptr, ctypes.POINTER(pyogg.opus.opus_int16)) # Read the next section of PCM - ns = pyogg.opus.op_read( - opusfile, - ptr, - remaining_buffer, - pyogg.ogg.c_int_p() - ) + ns = pyogg.opus.op_read(opusfile, ptr, remaining_buffer, pyogg.ogg.c_int_p()) # Check for errors if ns < 0: raise ValueError( - "Error while reading OggOpus buffer. "+ - "Error code: {}".format(ns) + "Error while reading OggOpus buffer. " + "Error code: {}".format(ns) ) # Increment the pointer - buf_ptr.value += ( - ns - * bytes_per_sample - * channel_count - ) - assert buf_ptr.value is not None # for mypy + buf_ptr.value += ns * bytes_per_sample * channel_count + assert buf_ptr.value is not None # for mypy samples += ns @@ -448,7 +492,7 @@ def read_ogg_opus(ogg_file): def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): # wav_file is already a file-pointer here - with wave.open(wav_file, 'wb') as wav_file_writer: + with wave.open(wav_file, "wb") as wav_file_writer: wav_file_writer.setframerate(audio_format.rate) wav_file_writer.setnchannels(audio_format.channels) wav_file_writer.setsampwidth(audio_format.width) @@ -457,7 +501,7 @@ def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): def read_wav(wav_file): wav_file.seek(0) - with wave.open(wav_file, 'rb') as wav_file_reader: + with wave.open(wav_file, "rb") as wav_file_reader: audio_format = read_audio_format_from_wav_file(wav_file_reader) pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) return audio_format, pcm_data @@ -470,20 +514,24 @@ def read_audio(audio_type, audio_file): return read_opus(audio_file) if audio_type == AUDIO_TYPE_OGG_OPUS: return read_ogg_opus(audio_file) - raise ValueError('Unsupported audio type: {}'.format(audio_type)) + raise ValueError("Unsupported audio type: {}".format(audio_type)) -def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, bitrate=None): +def write_audio( + audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, bitrate=None +): if audio_type == AUDIO_TYPE_WAV: return write_wav(audio_file, pcm_data, audio_format=audio_format) if audio_type == AUDIO_TYPE_OPUS: - return write_opus(audio_file, pcm_data, audio_format=audio_format, bitrate=bitrate) - raise ValueError('Unsupported audio type: {}'.format(audio_type)) + return write_opus( + audio_file, pcm_data, audio_format=audio_format, bitrate=bitrate + ) + raise ValueError("Unsupported audio type: {}".format(audio_type)) def read_wav_duration(wav_file): wav_file.seek(0) - with wave.open(wav_file, 'rb') as wav_file_reader: + with wave.open(wav_file, "rb") as wav_file_reader: return wav_file_reader.getnframes() / wav_file_reader.getframerate() @@ -499,19 +547,18 @@ def read_ogg_opus_duration(ogg_file): opusfile = pyogg.opus.op_open_memory( ubyte_array.from_buffer(ogg_file_buffer), len(ogg_file_buffer), - ctypes.pointer(error) + ctypes.pointer(error), ) if error.value != 0: raise ValueError( - ("Ogg/Opus buffer could not be read." - "Error code: {}").format(error.value) + ("Ogg/Opus buffer could not be read." "Error code: {}").format(error.value) ) pcm_buffer_size = pyogg.opus.op_pcm_total(opusfile, -1) channel_count = pyogg.opus.op_channel_count(opusfile, -1) - sample_rate = 48000 # opus files are always 48kHz - sample_width = 2 # always 16-bit + sample_rate = 48000 # opus files are always 48kHz + sample_width = 2 # always 16-bit audio_format = AudioFormat(sample_rate, channel_count, sample_width) pyogg.opus.op_free(opusfile) return get_pcm_duration(pcm_buffer_size, audio_format) @@ -524,12 +571,12 @@ def read_duration(audio_type, audio_file): return read_opus_duration(audio_file) if audio_type == AUDIO_TYPE_OGG_OPUS: return read_ogg_opus_duration(audio_file) - raise ValueError('Unsupported audio type: {}'.format(audio_type)) + raise ValueError("Unsupported audio type: {}".format(audio_type)) def read_wav_format(wav_file): wav_file.seek(0) - with wave.open(wav_file, 'rb') as wav_file_reader: + with wave.open(wav_file, "rb") as wav_file_reader: return read_audio_format_from_wav_file(wav_file_reader) @@ -545,20 +592,19 @@ def read_ogg_opus_format(ogg_file): opusfile = pyogg.opus.op_open_memory( ubyte_array.from_buffer(ogg_file_buffer), len(ogg_file_buffer), - ctypes.pointer(error) + ctypes.pointer(error), ) if error.value != 0: raise ValueError( - ("Ogg/Opus buffer could not be read." - "Error code: {}").format(error.value) + ("Ogg/Opus buffer could not be read." "Error code: {}").format(error.value) ) channel_count = pyogg.opus.op_channel_count(opusfile, -1) pyogg.opus.op_free(opusfile) - sample_rate = 48000 # opus files are always 48kHz - sample_width = 2 # always 16-bit + sample_rate = 48000 # opus files are always 48kHz + sample_width = 2 # always 16-bit return AudioFormat(sample_rate, channel_count, sample_width) @@ -569,12 +615,12 @@ def read_format(audio_type, audio_file): return read_opus_format(audio_file) if audio_type == AUDIO_TYPE_OGG_OPUS: return read_ogg_opus_format(audio_file) - raise ValueError('Unsupported audio type: {}'.format(audio_type)) + raise ValueError("Unsupported audio type: {}".format(audio_type)) def get_dtype(audio_format): if audio_format.width not in [1, 2, 4]: - raise ValueError('Unsupported sample width: {}'.format(audio_format.width)) + raise ValueError("Unsupported sample width: {}".format(audio_format.width)) return [None, np.int8, np.int16, None, np.int32][audio_format.width] @@ -589,8 +635,8 @@ def pcm_to_np(pcm_data, audio_format=DEFAULT_FORMAT): # Read interleaved channels nchannels = audio_format.channels - samples = samples.reshape((int(len(samples)/nchannels), nchannels)) - + samples = samples.reshape((int(len(samples) / nchannels), nchannels)) + # Convert to 0.0-1.0 range samples = samples.astype(np.float32) / np.iinfo(dtype).max @@ -624,4 +670,7 @@ def gain_db_to_ratio(gain_db): def normalize_audio(sample_data, dbfs=3.0103): - return np.maximum(np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), 1.0), -1.0) + return np.maximum( + np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), 1.0), + -1.0, + ) diff --git a/training/coqui_stt_training/util/augmentations.py b/training/coqui_stt_training/util/augmentations.py index dee7a725..f32c4214 100644 --- a/training/coqui_stt_training/util/augmentations.py +++ b/training/coqui_stt_training/util/augmentations.py @@ -1,18 +1,33 @@ -import os -import re import math +import os import random -import resampy -import numpy as np +import re +from multiprocessing import Process, Queue -from multiprocessing import Queue, Process -from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS -from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE -from .sample_collections import samples_from_source, unpack_maybe +import numpy as np +import resampy + +from .audio import ( + AUDIO_TYPE_NP, + AUDIO_TYPE_OPUS, + AUDIO_TYPE_PCM, + gain_db_to_ratio, + max_dbfs, + normalize_audio, +) +from .helpers import ( + MEGABYTE, + LimitingPool, + float_range, + int_range, + pick_value_from_range, + tf_pick_value_from_range, +) from .logging import log_info +from .sample_collections import samples_from_source, unpack_maybe BUFFER_SIZE = 1 * MEGABYTE -SPEC_PARSER = re.compile(r'^(?P[a-z_]+)(\[(?P.*)\])?$') +SPEC_PARSER = re.compile(r"^(?P[a-z_]+)(\[(?P.*)\])?$") class Augmentation: @@ -32,10 +47,10 @@ class SampleAugmentation(Augmentation): class GraphAugmentation(Augmentation): - def __init__(self, p=1.0, domain='spectrogram'): + def __init__(self, p=1.0, domain="spectrogram"): super(GraphAugmentation, self).__init__(p) - if domain not in ['signal', 'spectrogram', 'features']: - raise ValueError('Unsupported augmentation domain: {}'.format(domain)) + if domain not in ["signal", "spectrogram", "features"]: + raise ValueError("Unsupported augmentation domain: {}".format(domain)) self.domain = domain def apply(self, tensor, transcript=None, clock=0.0): @@ -43,19 +58,31 @@ class GraphAugmentation(Augmentation): def apply_with_probability(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel - rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max)) - return tf.cond(tf.less(rv, self.probability), - lambda: self.apply(tensor, transcript=transcript, clock=clock), - lambda: tensor) + + rv = tf.random.stateless_uniform( + [], seed=(clock * tf.int32.min, clock * tf.int32.max) + ) + return tf.cond( + tf.less(rv, self.probability), + lambda: self.apply(tensor, transcript=transcript, clock=clock), + lambda: tensor, + ) def maybe_apply(self, domain, tensor, transcript=None, clock=0.0): if domain == self.domain: - return self.apply_with_probability(tensor, transcript=transcript, clock=clock) + return self.apply_with_probability( + tensor, transcript=transcript, clock=clock + ) return tensor def units_per_ms(self): from .flags import FLAGS # pylint: disable=import-outside-toplevel - return FLAGS.audio_sample_rate / 1000.0 if self.domain == 'signal' else 1.0 / FLAGS.feature_win_step + + return ( + FLAGS.audio_sample_rate / 1000.0 + if self.domain == "signal" + else 1.0 / FLAGS.feature_win_step + ) def parse_augmentation(augmentation_spec): @@ -73,24 +100,34 @@ def parse_augmentation(augmentation_spec): """ match = SPEC_PARSER.match(augmentation_spec) if not match: - raise ValueError('Augmentation specification has wrong format') - cls_name = ''.join(map(lambda p: p[0].upper() + p[1:], match.group('cls').split('_'))) + raise ValueError("Augmentation specification has wrong format") + cls_name = "".join( + map(lambda p: p[0].upper() + p[1:], match.group("cls").split("_")) + ) augmentation_cls = globals()[cls_name] if cls_name in globals() else None - if augmentation_cls is None or not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation: - raise ValueError('Unknown augmentation: {}'.format(cls_name)) - parameters = match.group('params') - parameters = [] if parameters is None else parameters.split(',') + if ( + augmentation_cls is None + or not issubclass(augmentation_cls, Augmentation) + or augmentation_cls == Augmentation + ): + raise ValueError("Unknown augmentation: {}".format(cls_name)) + parameters = match.group("params") + parameters = [] if parameters is None else parameters.split(",") args = [] kwargs = {} for parameter in parameters: - pair = tuple(list(map(str.strip, (parameter.split('='))))) + pair = tuple(list(map(str.strip, (parameter.split("="))))) if len(pair) == 1: args.append(pair) elif len(pair) == 2: kwargs[pair[0]] = pair[1] else: - raise ValueError('Unable to parse augmentation value assignment') - log_info('Processed augmentation type: [{}] with parameter settings: {}'.format(augmentation_cls.__name__, kwargs)) + raise ValueError("Unable to parse augmentation value assignment") + log_info( + "Processed augmentation type: [{}] with parameter settings: {}".format( + augmentation_cls.__name__, kwargs + ) + ) return augmentation_cls(*args, **kwargs) @@ -110,7 +147,9 @@ def parse_augmentations(augmentation_specs): return list(map(parse_augmentation, augmentation_specs or [])) -def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, clock=0.0): +def apply_graph_augmentations( + domain, tensor, augmentations, transcript=None, clock=0.0 +): """ Augments training sample tensor of a certain domain with matching augmentations of passed list. @@ -134,7 +173,9 @@ def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, cl if augmentations: for augmentation in augmentations: if isinstance(augmentation, GraphAugmentation): - tensor = augmentation.maybe_apply(domain, tensor, transcript=transcript, clock=clock) + tensor = augmentation.maybe_apply( + domain, tensor, transcript=transcript, clock=clock + ) return tensor @@ -168,13 +209,15 @@ def _augment_sample(timed_sample, context=None): return sample -def apply_sample_augmentations(samples, - augmentations, - audio_type=AUDIO_TYPE_NP, - buffering=BUFFER_SIZE, - process_ahead=None, - clock=0.0, - final_clock=None): +def apply_sample_augmentations( + samples, + augmentations, + audio_type=AUDIO_TYPE_NP, + buffering=BUFFER_SIZE, + process_ahead=None, + clock=0.0, + final_clock=None, +): """ Prepares samples for being used during training. This includes parallel and buffered application of augmentations and a conversion to a specified audio-type. @@ -201,20 +244,27 @@ def apply_sample_augmentations(samples, ------- iterable of util.sample_collections.LabeledSample or util.audio.Sample """ + def timed_samples(): if final_clock is None: for sample in samples: yield sample, clock else: for sample_index, sample in enumerate(samples): - sample_clock = clock + (final_clock - clock) * (sample_index / len(samples)) + sample_clock = clock + (final_clock - clock) * ( + sample_index / len(samples) + ) yield sample, sample_clock assert 0.0 <= clock <= 1.0 if final_clock is not None: assert 0.0 <= final_clock <= 1.0 assert clock <= final_clock - augmentations = [aug for aug in augmentations if isinstance(aug, SampleAugmentation)] if augmentations else [] + augmentations = ( + [aug for aug in augmentations if isinstance(aug, SampleAugmentation)] + if augmentations + else [] + ) try: for augmentation in augmentations: augmentation.start(buffering=buffering) @@ -223,9 +273,11 @@ def apply_sample_augmentations(samples, for timed_sample in timed_samples(): yield _load_and_augment_sample(timed_sample, context=context) else: - with LimitingPool(process_ahead=process_ahead, - initializer=_init_augmentation_worker, - initargs=(context,)) as pool: + with LimitingPool( + process_ahead=process_ahead, + initializer=_init_augmentation_worker, + initargs=(context,), + ) as pool: yield from pool.imap(_load_and_augment_sample, timed_samples()) finally: for augmentation in augmentations: @@ -247,6 +299,7 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE): class Overlay(SampleAugmentation): """See "Overlay augmentation" in training documentation""" + def __init__(self, source, p=1.0, snr=3.0, layers=1): super(Overlay, self).__init__(p) self.source = source @@ -257,10 +310,14 @@ class Overlay(SampleAugmentation): self.enqueue_process = None def start(self, buffering=BUFFER_SIZE): - self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count()))) - self.enqueue_process = Process(target=_enqueue_overlay_samples, - args=(self.source, self.queue), - kwargs={'buffering': buffering}) + self.queue = Queue( + max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())) + ) + self.enqueue_process = Process( + target=_enqueue_overlay_samples, + args=(self.source, self.queue), + kwargs={"buffering": buffering}, + ) self.enqueue_process.start() def apply(self, sample, clock=0.0): @@ -280,11 +337,15 @@ class Overlay(SampleAugmentation): n_required = len(audio) - overlay_offset n_current = len(self.current_sample) if n_required >= n_current: # take it completely - overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample + overlay_data[ + overlay_offset : overlay_offset + n_current + ] += self.current_sample overlay_offset += n_current self.current_sample = None else: # take required slice from head and keep tail for next layer or sample - overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required] + overlay_data[ + overlay_offset : overlay_offset + n_required + ] += self.current_sample[0:n_required] overlay_offset += n_required self.current_sample = self.current_sample[n_required:] snr_db = pick_value_from_range(self.snr, clock=clock) @@ -303,18 +364,24 @@ class Overlay(SampleAugmentation): class Codec(SampleAugmentation): """See "Codec augmentation" in training documentation""" + def __init__(self, p=1.0, bitrate=3200): super(Codec, self).__init__(p) self.bitrate = int_range(bitrate) def apply(self, sample, clock=0.0): bitrate = pick_value_from_range(self.bitrate, clock=clock) - sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM) # decoding to ensure it has to get encoded again - sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream + sample.change_audio_type( + new_audio_type=AUDIO_TYPE_PCM + ) # decoding to ensure it has to get encoded again + sample.change_audio_type( + new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate + ) # will get decoded again downstream class Reverb(SampleAugmentation): """See "Reverb augmentation" in training documentation""" + def __init__(self, p=1.0, delay=20.0, decay=10.0): super(Reverb, self).__init__(p) self.delay = float_range(delay) @@ -331,13 +398,17 @@ class Reverb(SampleAugmentation): primes = [17, 19, 23, 29, 31] for delay_prime in primes: # primes to minimize comb filter interference layer = np.copy(audio) - n_delay = math.floor(delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0) - n_delay = max(16, n_delay) # 16 samples minimum to avoid performance trap and risk of division by zero + n_delay = math.floor( + delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0 + ) + n_delay = max( + 16, n_delay + ) # 16 samples minimum to avoid performance trap and risk of division by zero for w_index in range(0, math.floor(len(audio) / n_delay)): w1 = w_index * n_delay w2 = (w_index + 1) * n_delay width = min(len(audio) - w2, n_delay) # last window could be smaller - layer[w2:w2 + width] += decay * layer[w1:w1 + width] + layer[w2 : w2 + width] += decay * layer[w1 : w1 + width] result += layer audio = normalize_audio(result, dbfs=orig_dbfs) sample.audio = np.array(audio, dtype=np.float32) @@ -345,6 +416,7 @@ class Reverb(SampleAugmentation): class Resample(SampleAugmentation): """See "Resample augmentation" in training documentation""" + def __init__(self, p=1.0, rate=8000): super(Resample, self).__init__(p) self.rate = int_range(rate) @@ -353,8 +425,12 @@ class Resample(SampleAugmentation): sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) rate = pick_value_from_range(self.rate, clock=clock) orig_len = len(sample.audio) - resampled = resampy.resample(sample.audio, sample.audio_format.rate, rate, axis=0, filter='kaiser_fast') - sample.audio = resampy.resample(resampled, rate, sample.audio_format.rate, axis=0, filter='kaiser_fast')[:orig_len] + resampled = resampy.resample( + sample.audio, sample.audio_format.rate, rate, axis=0, filter="kaiser_fast" + ) + sample.audio = resampy.resample( + resampled, rate, sample.audio_format.rate, axis=0, filter="kaiser_fast" + )[:orig_len] class NormalizeSampleRate(SampleAugmentation): @@ -367,12 +443,19 @@ class NormalizeSampleRate(SampleAugmentation): return sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) - sample.audio = resampy.resample(sample.audio, sample.audio_format.rate, self.rate, axis=0, filter='kaiser_fast') + sample.audio = resampy.resample( + sample.audio, + sample.audio_format.rate, + self.rate, + axis=0, + filter="kaiser_fast", + ) sample.audio_format = sample.audio_format._replace(rate=self.rate) class Volume(SampleAugmentation): """See "Volume augmentation" in training documentation""" + def __init__(self, p=1.0, dbfs=3.0103): super(Volume, self).__init__(p) self.target_dbfs = float_range(dbfs) @@ -385,55 +468,76 @@ class Volume(SampleAugmentation): class Pitch(GraphAugmentation): """See "Pitch augmentation" in training documentation""" + def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)): - super(Pitch, self).__init__(p, domain='spectrogram') + super(Pitch, self).__init__(p, domain="spectrogram") self.pitch = float_range(pitch) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + original_shape = tf.shape(tensor) pitch = tf_pick_value_from_range(self.pitch, clock=clock) - new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32) - spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size]) - spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, - offset_height=0, - offset_width=0, - target_height=original_shape[1], - target_width=tf.math.minimum(original_shape[2], new_freq_size)) - spectrogram_aug = tf.cond(pitch < 1, - lambda: tf.image.pad_to_bounding_box(spectrogram_aug, - offset_height=0, - offset_width=0, - target_height=tf.shape(spectrogram_aug)[1], - target_width=original_shape[2]), - lambda: spectrogram_aug) + new_freq_size = tf.cast( + tf.cast(original_shape[2], tf.float32) * pitch, tf.int32 + ) + spectrogram_aug = tf.image.resize_bilinear( + tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size] + ) + spectrogram_aug = tf.image.crop_to_bounding_box( + spectrogram_aug, + offset_height=0, + offset_width=0, + target_height=original_shape[1], + target_width=tf.math.minimum(original_shape[2], new_freq_size), + ) + spectrogram_aug = tf.cond( + pitch < 1, + lambda: tf.image.pad_to_bounding_box( + spectrogram_aug, + offset_height=0, + offset_width=0, + target_height=tf.shape(spectrogram_aug)[1], + target_width=original_shape[2], + ), + lambda: spectrogram_aug, + ) return spectrogram_aug[:, :, :, 0] class Tempo(GraphAugmentation): """See "Tempo augmentation" in training documentation""" + def __init__(self, p=1.0, factor=1.1, max_time=-1): - super(Tempo, self).__init__(p, domain='spectrogram') + super(Tempo, self).__init__(p, domain="spectrogram") self.factor = float_range(factor) self.max_time = float(max_time) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + factor = tf_pick_value_from_range(self.factor, clock=clock) original_shape = tf.shape(tensor) - new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32) + new_time_size = tf.cast( + tf.cast(original_shape[1], tf.float32) / factor, tf.int32 + ) if transcript is not None: new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1]) if self.max_time > 0: - new_time_size = tf.math.minimum(new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32)) - spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]]) + new_time_size = tf.math.minimum( + new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32) + ) + spectrogram_aug = tf.image.resize_bilinear( + tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]] + ) return spectrogram_aug[:, :, :, 0] class Warp(GraphAugmentation): """See "Warp augmentation" in training documentation""" + def __init__(self, p=1.0, nt=1, nf=1, wt=0.1, wf=0.0): - super(Warp, self).__init__(p, domain='spectrogram') + super(Warp, self).__init__(p, domain="spectrogram") self.num_t = int_range(nt) self.num_f = int_range(nf) self.warp_t = float_range(wt) @@ -441,6 +545,7 @@ class Warp(GraphAugmentation): def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + original_shape = tf.shape(tensor) size_t, size_f = original_shape[1], original_shape[2] seed = (clock * tf.int32.min, clock * tf.int32.max) @@ -449,25 +554,43 @@ class Warp(GraphAugmentation): def get_flows(n, size, warp): warp = tf_pick_value_from_range(warp, clock=clock) - warp = warp * tf.cast(size, dtype=tf.float32) / tf.cast(2 * (n + 1), dtype=tf.float32) - f = tf.random.stateless_normal([num_t, num_f], seed, mean=0.0, stddev=warp, dtype=tf.float32) - return tf.pad(f, tf.constant([[1, 1], [1, 1]]), 'CONSTANT') # zero flow at all edges + warp = ( + warp + * tf.cast(size, dtype=tf.float32) + / tf.cast(2 * (n + 1), dtype=tf.float32) + ) + f = tf.random.stateless_normal( + [num_t, num_f], seed, mean=0.0, stddev=warp, dtype=tf.float32 + ) + return tf.pad( + f, tf.constant([[1, 1], [1, 1]]), "CONSTANT" + ) # zero flow at all edges - flows = tf.stack([get_flows(num_t, size_t, self.warp_t), get_flows(num_f, size_f, self.warp_f)], axis=2) + flows = tf.stack( + [ + get_flows(num_t, size_t, self.warp_t), + get_flows(num_f, size_f, self.warp_f), + ], + axis=2, + ) flows = tf.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f]) - spectrogram_aug = tf.contrib.image.dense_image_warp(tf.expand_dims(tensor, -1), flows) + spectrogram_aug = tf.contrib.image.dense_image_warp( + tf.expand_dims(tensor, -1), flows + ) return tf.reshape(spectrogram_aug, shape=(1, -1, size_f)) class FrequencyMask(GraphAugmentation): """See "Frequency mask augmentation" in training documentation""" + def __init__(self, p=1.0, n=3, size=2): - super(FrequencyMask, self).__init__(p, domain='spectrogram') + super(FrequencyMask, self).__init__(p, domain="spectrogram") self.n = int_range(n) # pylint: disable=invalid-name self.size = int_range(size) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + time_max = tf.shape(tensor)[1] freq_max = tf.shape(tensor)[2] n = tf_pick_value_from_range(self.n, clock=clock) @@ -476,10 +599,21 @@ class FrequencyMask(GraphAugmentation): size = tf_pick_value_from_range(self.size, clock=clock) size = tf.math.maximum(1, tf.math.minimum(freq_max - 1, size)) seed = tf.cast(clock * tf.int32.max, tf.int32) - i - f0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=freq_max - size, dtype=tf.dtypes.int32) - freq_mask = tf.concat([tf.ones([1, time_max, f0]), - tf.zeros([1, time_max, size]), - tf.ones([1, time_max, freq_max - f0 - size])], axis=2) + f0 = tf.random.stateless_uniform( + (), + (-seed, seed), + minval=0, + maxval=freq_max - size, + dtype=tf.dtypes.int32, + ) + freq_mask = tf.concat( + [ + tf.ones([1, time_max, f0]), + tf.zeros([1, time_max, size]), + tf.ones([1, time_max, freq_max - f0 - size]), + ], + axis=2, + ) return i + 1, spectrogram_aug * freq_mask return tf.while_loop(lambda i, spectrogram_aug: i < n, body, (0, tensor))[1] @@ -487,29 +621,51 @@ class FrequencyMask(GraphAugmentation): class TimeMask(GraphAugmentation): """See "Time mask augmentation" in training documentation""" - def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0): + + def __init__(self, p=1.0, domain="spectrogram", n=3, size=10.0): super(TimeMask, self).__init__(p, domain=domain) self.n = int_range(n) # pylint: disable=invalid-name self.size = float_range(size) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel - time_max = tf.shape(tensor)[0 if self.domain == 'signal' else 1] + + time_max = tf.shape(tensor)[0 if self.domain == "signal" else 1] n = tf_pick_value_from_range(self.n, clock=clock) def body(i, augmented): - size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(), dtype=tf.int32) + size = tf.cast( + tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(), + dtype=tf.int32, + ) size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size)) seed = tf.cast(clock * tf.int32.max, tf.int32) - i - t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32) + t0 = tf.random.stateless_uniform( + (), + (-seed, seed), + minval=0, + maxval=time_max - size, + dtype=tf.dtypes.int32, + ) rest = time_max - t0 - size - if self.domain == 'spectrogram': + if self.domain == "spectrogram": fm = tf.shape(tensor)[2] - time_mask = tf.concat([tf.ones([1, t0, fm]), tf.zeros([1, size, fm]), tf.ones([1, rest, fm])], axis=1) - elif self.domain == 'signal': - time_mask = tf.concat([tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0) + time_mask = tf.concat( + [ + tf.ones([1, t0, fm]), + tf.zeros([1, size, fm]), + tf.ones([1, rest, fm]), + ], + axis=1, + ) + elif self.domain == "signal": + time_mask = tf.concat( + [tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0 + ) else: - time_mask = tf.concat([tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1) + time_mask = tf.concat( + [tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1 + ) return i + 1, augmented * time_mask return tf.while_loop(lambda i, augmented: i < n, body, (0, tensor))[1] @@ -517,43 +673,55 @@ class TimeMask(GraphAugmentation): class Dropout(GraphAugmentation): """See "Dropout augmentation" in training documentation""" - def __init__(self, p=1.0, domain='spectrogram', rate=0.05): + + def __init__(self, p=1.0, domain="spectrogram", rate=0.05): super(Dropout, self).__init__(p, domain=domain) self.rate = float_range(rate) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + rate = tf_pick_value_from_range(self.rate, clock=clock) rate = tf.math.maximum(0.0, rate) - factors = tf.random.stateless_uniform(tf.shape(tensor), - (clock * tf.int32.min, clock * tf.int32.max), - minval=0.0, - maxval=1.0, - dtype=tf.float32) + factors = tf.random.stateless_uniform( + tf.shape(tensor), + (clock * tf.int32.min, clock * tf.int32.max), + minval=0.0, + maxval=1.0, + dtype=tf.float32, + ) return tensor * tf.math.sign(tf.math.floor(factors + rate)) class Add(GraphAugmentation): """See "Add augmentation" in training documentation""" - def __init__(self, p=1.0, domain='features', stddev=5): + + def __init__(self, p=1.0, domain="features", stddev=5): super(Add, self).__init__(p, domain=domain) self.stddev = float_range(stddev) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + stddev = tf_pick_value_from_range(self.stddev, clock=clock) seed = (clock * tf.int32.min, clock * tf.int32.max) - return tensor + tf.random.stateless_normal(tf.shape(tensor), seed, mean=0.0, stddev=stddev) + return tensor + tf.random.stateless_normal( + tf.shape(tensor), seed, mean=0.0, stddev=stddev + ) class Multiply(GraphAugmentation): """See "Multiply augmentation" in training documentation""" - def __init__(self, p=1.0, domain='features', stddev=5): + + def __init__(self, p=1.0, domain="features", stddev=5): super(Multiply, self).__init__(p, domain=domain) self.stddev = float_range(stddev) def apply(self, tensor, transcript=None, clock=0.0): import tensorflow as tf # pylint: disable=import-outside-toplevel + stddev = tf_pick_value_from_range(self.stddev, clock=clock) seed = (clock * tf.int32.min, clock * tf.int32.max) - return tensor * tf.random.stateless_normal(tf.shape(tensor), seed, mean=1.0, stddev=stddev) + return tensor * tf.random.stateless_normal( + tf.shape(tensor), seed, mean=1.0, stddev=stddev + ) diff --git a/training/coqui_stt_training/util/check_characters.py b/training/coqui_stt_training/util/check_characters.py index cf7af861..fc8630a8 100644 --- a/training/coqui_stt_training/util/check_characters.py +++ b/training/coqui_stt_training/util/check_characters.py @@ -22,14 +22,31 @@ import csv import os import sys import unicodedata + from .io import open_remote + def main(): parser = argparse.ArgumentParser() - parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True) - parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true") - parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true") + parser.add_argument( + "-csv", + "--csv-files", + help="Str. Filenames as a comma separated list", + required=True, + ) + parser.add_argument( + "-alpha", + "--alphabet-format", + help="Bool. Print in format for alphabet.txt", + action="store_true", + ) + parser.add_argument( + "-unicode", + "--disable-unicode-variants", + help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", + action="store_true", + ) args = parser.parse_args() in_files = args.csv_files.split(",") @@ -46,11 +63,21 @@ def main(): if not args.disable_unicode_variants: unicode_transcript = unicodedata.normalize("NFKC", row[2]) if row[2] != unicode_transcript: - print("Your input file", in_file, "contains at least one transript with unicode chars on more than one code-point: '{}'. Consider using NFKC normalization: unicodedata.normalize('NFKC', str).".format(row[2])) + print( + "Your input file", + in_file, + "contains at least one transript with unicode chars on more than one code-point: '{}'. Consider using NFKC normalization: unicodedata.normalize('NFKC', str).".format( + row[2] + ), + ) sys.exit(-1) all_text |= set(row[2]) except IndexError: - print("Your input file", in_file, "is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript") + print( + "Your input file", + in_file, + "is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript", + ) sys.exit(-1) finally: csv_file.close() @@ -63,5 +90,6 @@ def main(): else: print(list(all_text)) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/training/coqui_stt_training/util/checkpoints.py b/training/coqui_stt_training/util/checkpoints.py index 459a4d06..e88fd3f6 100644 --- a/training/coqui_stt_training/util/checkpoints.py +++ b/training/coqui_stt_training/util/checkpoints.py @@ -1,9 +1,10 @@ import sys + import tensorflow as tf import tensorflow.compat.v1 as tfv1 from .flags import FLAGS -from .logging import log_info, log_error, log_warn +from .logging import log_error, log_info, log_warn def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True): @@ -17,9 +18,11 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init= # We explicitly allow the learning rate variable to be missing for backwards # compatibility with older checkpoints. - lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') - if lr_var and ('learning_rate' not in vars_in_ckpt or - (FLAGS.force_initialize_learning_rate and allow_lr_init)): + lr_var = set(v for v in load_vars if v.op.name == "learning_rate") + if lr_var and ( + "learning_rate" not in vars_in_ckpt + or (FLAGS.force_initialize_learning_rate and allow_lr_init) + ): assert len(lr_var) <= 1 load_vars -= lr_var init_vars |= lr_var @@ -31,7 +34,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init= missing_vars = set() for v in load_vars: if v.op.name not in vars_in_ckpt: - log_warn('CUDNN variable not found: %s' % (v.op.name)) + log_warn("CUDNN variable not found: %s" % (v.op.name)) missing_vars.add(v) init_vars.add(v) @@ -40,10 +43,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init= # Check that the only missing variables (i.e. those to be initialised) # are the Adam moment tensors, if they aren't then we have an issue missing_var_names = [v.op.name for v in missing_vars] - if any('Adam' not in v for v in missing_var_names): - log_error('Tried to load a CuDNN RNN checkpoint but there were ' - 'more missing variables than just the Adam moment ' - 'tensors. Missing variables: {}'.format(missing_var_names)) + if any("Adam" not in v for v in missing_var_names): + log_error( + "Tried to load a CuDNN RNN checkpoint but there were " + "more missing variables than just the Adam moment " + "tensors. Missing variables: {}".format(missing_var_names) + ) sys.exit(1) if allow_drop_layers and FLAGS.drop_source_layers > 0: @@ -54,12 +59,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init= # If we want to use all layers from the source model except # the last one, we use this: drop_source_layers=1 if FLAGS.drop_source_layers >= 6: - log_warn('The checkpoint only has 6 layers, but you are trying to drop ' - 'all of them or more than all of them. Continuing and ' - 'dropping only 5 layers.') + log_warn( + "The checkpoint only has 6 layers, but you are trying to drop " + "all of them or more than all of them. Continuing and " + "dropping only 5 layers." + ) FLAGS.drop_source_layers = 5 - dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):] + dropped_layers = ["2", "3", "lstm", "5", "6"][ + -1 * int(FLAGS.drop_source_layers) : + ] # Initialize all variables needed for DS, but not loaded from ckpt for v in load_vars: if any(layer in v.op.name for layer in dropped_layers): @@ -67,16 +76,18 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init= load_vars -= init_vars for v in sorted(load_vars, key=lambda v: v.op.name): - log_info('Loading variable from checkpoint: %s' % (v.op.name)) + log_info("Loading variable from checkpoint: %s" % (v.op.name)) v.load(ckpt.get_tensor(v.op.name), session=session) for v in sorted(init_vars, key=lambda v: v.op.name): - log_info('Initializing variable: %s' % (v.op.name)) + log_info("Initializing variable: %s" % (v.op.name)) session.run(v.initializer) def _checkpoint_path_or_none(checkpoint_filename): - checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename) + checkpoint = tfv1.train.get_checkpoint_state( + FLAGS.load_checkpoint_dir, checkpoint_filename + ) if not checkpoint: return None return checkpoint.model_checkpoint_path @@ -91,61 +102,65 @@ def _initialize_all_variables(session): def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True): for method in method_order: # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint' - if method == 'best': - ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint') + if method == "best": + ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint") if ckpt_path: - log_info('Loading best validating checkpoint from {}'.format(ckpt_path)) - return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init) - log_info('Could not find best validating checkpoint.') + log_info("Loading best validating checkpoint from {}".format(ckpt_path)) + return _load_checkpoint( + session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init + ) + log_info("Could not find best validating checkpoint.") # Load most recent checkpoint, saved in checkpoint file 'checkpoint' - elif method == 'last': - ckpt_path = _checkpoint_path_or_none('checkpoint') + elif method == "last": + ckpt_path = _checkpoint_path_or_none("checkpoint") if ckpt_path: - log_info('Loading most recent checkpoint from {}'.format(ckpt_path)) - return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init) - log_info('Could not find most recent checkpoint.') + log_info("Loading most recent checkpoint from {}".format(ckpt_path)) + return _load_checkpoint( + session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init + ) + log_info("Could not find most recent checkpoint.") # Initialize all variables - elif method == 'init': - log_info('Initializing all variables.') + elif method == "init": + log_info("Initializing all variables.") return _initialize_all_variables(session) else: - log_error('Unknown initialization method: {}'.format(method)) + log_error("Unknown initialization method: {}".format(method)) sys.exit(1) - log_error('All initialization methods failed ({}).'.format(method_order)) + log_error("All initialization methods failed ({}).".format(method_order)) sys.exit(1) def reload_best_checkpoint(session): - _load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False) + _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False) def load_or_init_graph_for_training(session): - ''' + """ Load variables from checkpoint or initialize variables. By default this will try to load the best validating checkpoint, then try the last checkpoint, and finally initialize the weights from scratch. This can be overriden with the `--load_train` flag. See its documentation for more info. - ''' - if FLAGS.load_train == 'auto': - methods = ['best', 'last', 'init'] + """ + if FLAGS.load_train == "auto": + methods = ["best", "last", "init"] else: methods = [FLAGS.load_train] _load_or_init_impl(session, methods, allow_drop_layers=True) def load_graph_for_evaluation(session): - ''' + """ Load variables from checkpoint. Initialization is not allowed. By default this will try to load the best validating checkpoint, then try the last checkpoint. This can be overriden with the `--load_evaluate` flag. See its documentation for more info. - ''' - if FLAGS.load_evaluate == 'auto': - methods = ['best', 'last'] + """ + if FLAGS.load_evaluate == "auto": + methods = ["best", "last"] else: methods = [FLAGS.load_evaluate] _load_or_init_impl(session, methods, allow_drop_layers=False) diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index f3b25bb3..8c468a58 100755 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -2,18 +2,20 @@ from __future__ import absolute_import, division, print_function import os import sys -import tensorflow.compat.v1 as tfv1 from attrdict import AttrDict -from xdg import BaseDirectory as xdg from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet +from xdg import BaseDirectory as xdg +import tensorflow.compat.v1 as tfv1 + +from .augmentations import NormalizeSampleRate, parse_augmentations from .flags import FLAGS from .gpu import get_available_gpus -from .logging import log_error, log_warn from .helpers import parse_file_size -from .augmentations import parse_augmentations, NormalizeSampleRate from .io import path_exists_remote +from .logging import log_error, log_warn + class ConfigSingleton: _config = None @@ -22,11 +24,14 @@ class ConfigSingleton: if not ConfigSingleton._config: raise RuntimeError("Global configuration not yet initialized.") if not hasattr(ConfigSingleton._config, name): - raise RuntimeError("Configuration option {} not found in config.".format(name)) + raise RuntimeError( + "Configuration option {} not found in config.".format(name) + ) return ConfigSingleton._config[name] -Config = ConfigSingleton() # pylint: disable=invalid-name +Config = ConfigSingleton() # pylint: disable=invalid-name + def initialize_globals(): c = AttrDict() @@ -34,16 +39,22 @@ def initialize_globals(): # Augmentations c.augmentations = parse_augmentations(FLAGS.augment) if c.augmentations and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0: - log_warn('Due to current feature-cache settings the exact same sample augmentations of the first ' - 'epoch will be repeated on all following epochs. This could lead to unintended over-fitting. ' - 'You could use --cache_for_epochs to invalidate the cache after a given number of epochs.') + log_warn( + "Due to current feature-cache settings the exact same sample augmentations of the first " + "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. " + "You could use --cache_for_epochs to invalidate the cache after a given number of epochs." + ) if FLAGS.normalize_sample_rate: - c.augmentations = [NormalizeSampleRate(FLAGS.audio_sample_rate)] + c['augmentations'] + c.augmentations = [NormalizeSampleRate(FLAGS.audio_sample_rate)] + c[ + "augmentations" + ] # Caching if FLAGS.cache_for_epochs == 1: - log_warn('--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it.') + log_warn( + "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it." + ) # Read-buffer FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer) @@ -58,26 +69,29 @@ def initialize_globals(): # Set default checkpoint dir if not FLAGS.checkpoint_dir: - FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('stt', 'checkpoints')) + FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints")) - if FLAGS.load_train not in ['last', 'best', 'init', 'auto']: - FLAGS.load_train = 'auto' + if FLAGS.load_train not in ["last", "best", "init", "auto"]: + FLAGS.load_train = "auto" - if FLAGS.load_evaluate not in ['last', 'best', 'auto']: - FLAGS.load_evaluate = 'auto' + if FLAGS.load_evaluate not in ["last", "best", "auto"]: + FLAGS.load_evaluate = "auto" # Set default summary dir if not FLAGS.summary_dir: - FLAGS.summary_dir = xdg.save_data_path(os.path.join('stt', 'summaries')) + FLAGS.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries")) # Standard session configuration that'll be used for all new sessions. - c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement, - inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, - intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, - gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth)) + c.session_config = tfv1.ConfigProto( + allow_soft_placement=True, + log_device_placement=FLAGS.log_placement, + inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, + intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, + gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth), + ) # CPU device - c.cpu_device = '/cpu:0' + c.cpu_device = "/cpu:0" # Available GPU devices c.available_devices = get_available_gpus(c.session_config) @@ -98,10 +112,10 @@ def initialize_globals(): # doc/Geometry.md # Number of MFCC features - c.n_input = 26 # TODO: Determine this programmatically from the sample rate + c.n_input = 26 # TODO: Determine this programmatically from the sample rate # The number of frames in the context - c.n_context = 9 # TODO: Determine the optimal value using a validation data set + c.n_context = 9 # TODO: Determine the optimal value using a validation data set # Number of units in hidden layers c.n_hidden = FLAGS.n_hidden @@ -119,40 +133,54 @@ def initialize_globals(): c.n_hidden_3 = c.n_cell_dim # Units in the sixth layer = number of characters in the target language plus one - c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label + c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label # Size of audio window in samples if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0: - log_error('--feature_win_len value ({}) in milliseconds ({}) multiplied ' - 'by --audio_sample_rate value ({}) must be an integer value. Adjust ' - 'your --feature_win_len value or resample your audio accordingly.' - ''.format(FLAGS.feature_win_len, FLAGS.feature_win_len / 1000, FLAGS.audio_sample_rate)) + log_error( + "--feature_win_len value ({}) in milliseconds ({}) multiplied " + "by --audio_sample_rate value ({}) must be an integer value. Adjust " + "your --feature_win_len value or resample your audio accordingly." + "".format( + FLAGS.feature_win_len, + FLAGS.feature_win_len / 1000, + FLAGS.audio_sample_rate, + ) + ) sys.exit(1) c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000) # Stride for feature computations in samples if (FLAGS.feature_win_step * FLAGS.audio_sample_rate) % 1000 != 0: - log_error('--feature_win_step value ({}) in milliseconds ({}) multiplied ' - 'by --audio_sample_rate value ({}) must be an integer value. Adjust ' - 'your --feature_win_step value or resample your audio accordingly.' - ''.format(FLAGS.feature_win_step, FLAGS.feature_win_step / 1000, FLAGS.audio_sample_rate)) + log_error( + "--feature_win_step value ({}) in milliseconds ({}) multiplied " + "by --audio_sample_rate value ({}) must be an integer value. Adjust " + "your --feature_win_step value or resample your audio accordingly." + "".format( + FLAGS.feature_win_step, + FLAGS.feature_win_step / 1000, + FLAGS.audio_sample_rate, + ) + ) sys.exit(1) c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000) if FLAGS.one_shot_infer: if not path_exists_remote(FLAGS.one_shot_infer): - log_error('Path specified in --one_shot_infer is not a valid file.') + log_error("Path specified in --one_shot_infer is not a valid file.") sys.exit(1) if FLAGS.train_cudnn and FLAGS.load_cudnn: - log_error('Trying to use --train_cudnn, but --load_cudnn ' - 'was also specified. The --load_cudnn flag is only ' - 'needed when converting a CuDNN RNN checkpoint to ' - 'a CPU-capable graph. If your system is capable of ' - 'using CuDNN RNN, you can just specify the CuDNN RNN ' - 'checkpoint normally with --save_checkpoint_dir.') + log_error( + "Trying to use --train_cudnn, but --load_cudnn " + "was also specified. The --load_cudnn flag is only " + "needed when converting a CuDNN RNN checkpoint to " + "a CPU-capable graph. If your system is capable of " + "using CuDNN RNN, you can just specify the CuDNN RNN " + "checkpoint normally with --save_checkpoint_dir." + ) sys.exit(1) # If separate save and load flags were not specified, default to load and save @@ -163,4 +191,4 @@ def initialize_globals(): if not FLAGS.load_checkpoint_dir: FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir - ConfigSingleton._config = c # pylint: disable=protected-access + ConfigSingleton._config = c # pylint: disable=protected-access diff --git a/training/coqui_stt_training/util/downloader.py b/training/coqui_stt_training/util/downloader.py index c527eb9b..f559fb58 100644 --- a/training/coqui_stt_training/util/downloader.py +++ b/training/coqui_stt_training/util/downloader.py @@ -1,10 +1,18 @@ -import requests +from os import makedirs, path + import progressbar +import requests -from os import path, makedirs -from .io import open_remote, path_exists_remote, is_remote_path +from .io import is_remote_path, open_remote, path_exists_remote + +SIMPLE_BAR = [ + "Progress ", + progressbar.Bar(), + " ", + progressbar.Percentage(), + " completed", +] -SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed'] def maybe_download(archive_name, target_dir, archive_url): # If archive file does not exist, download it... @@ -17,12 +25,15 @@ def maybe_download(archive_name, target_dir, archive_url): if not path_exists_remote(archive_path): print('No archive "%s" - downloading...' % archive_path) req = requests.get(archive_url, stream=True) - total_size = int(req.headers.get('content-length', 0)) + total_size = int(req.headers.get("content-length", 0)) done = 0 - with open_remote(archive_path, 'wb') as f: - bar = progressbar.ProgressBar(max_value=total_size if total_size > 0 else progressbar.UnknownLength, widgets=SIMPLE_BAR) + with open_remote(archive_path, "wb") as f: + bar = progressbar.ProgressBar( + max_value=total_size if total_size > 0 else progressbar.UnknownLength, + widgets=SIMPLE_BAR, + ) - for data in req.iter_content(1024*1024): + for data in req.iter_content(1024 * 1024): done += len(data) f.write(data) bar.update(done) diff --git a/training/coqui_stt_training/util/evaluate_tools.py b/training/coqui_stt_training/util/evaluate_tools.py index 68d29f3e..62c78d38 100644 --- a/training/coqui_stt_training/util/evaluate_tools.py +++ b/training/coqui_stt_training/util/evaluate_tools.py @@ -9,8 +9,9 @@ import numpy as np from attrdict import AttrDict from .flags import FLAGS -from .text import levenshtein from .io import open_remote +from .text import levenshtein + def pmap(fun, iterable): pool = Pool() @@ -42,26 +43,28 @@ def process_decode_result(item): char_length = len(ground_truth) word_distance = levenshtein(ground_truth.split(), prediction.split()) word_length = len(ground_truth.split()) - return AttrDict({ - 'wav_filename': wav_filename, - 'src': ground_truth, - 'res': prediction, - 'loss': loss, - 'char_distance': char_distance, - 'char_length': char_length, - 'word_distance': word_distance, - 'word_length': word_length, - 'cer': char_distance / char_length, - 'wer': word_distance / word_length, - }) + return AttrDict( + { + "wav_filename": wav_filename, + "src": ground_truth, + "res": prediction, + "loss": loss, + "char_distance": char_distance, + "char_length": char_length, + "word_distance": word_distance, + "word_length": word_length, + "cer": char_distance / char_length, + "wer": word_distance / word_length, + } + ) def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name): - r''' + r""" This routine will calculate and print a WER report. It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER). - ''' + """ samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses)) # Getting the WER and CER from the accumulated edit distances and lengths @@ -88,41 +91,43 @@ def print_report(samples, losses, wer, cer, dataset_name): # Print summary mean_loss = np.mean(losses) - print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss)) - print('-' * 80) + print( + "Test on %s - WER: %f, CER: %f, loss: %f" % (dataset_name, wer, cer, mean_loss) + ) + print("-" * 80) - best_samples = samples[:FLAGS.report_count] - worst_samples = samples[-FLAGS.report_count:] + best_samples = samples[: FLAGS.report_count] + worst_samples = samples[-FLAGS.report_count :] median_index = int(len(samples) / 2) median_left = int(FLAGS.report_count / 2) median_right = FLAGS.report_count - median_left - median_samples = samples[median_index - median_left:median_index + median_right] + median_samples = samples[median_index - median_left : median_index + median_right] def print_single_sample(sample): - print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss)) - print(' - wav: file://%s' % sample.wav_filename) + print("WER: %f, CER: %f, loss: %f" % (sample.wer, sample.cer, sample.loss)) + print(" - wav: file://%s" % sample.wav_filename) print(' - src: "%s"' % sample.src) print(' - res: "%s"' % sample.res) - print('-' * 80) + print("-" * 80) - print('Best WER:', '\n' + '-' * 80) + print("Best WER:", "\n" + "-" * 80) for s in best_samples: print_single_sample(s) - print('Median WER:', '\n' + '-' * 80) + print("Median WER:", "\n" + "-" * 80) for s in median_samples: print_single_sample(s) - print('Worst WER:', '\n' + '-' * 80) + print("Worst WER:", "\n" + "-" * 80) for s in worst_samples: print_single_sample(s) def save_samples_json(samples, output_path): - ''' Save decoded tuples as JSON, converting NumPy floats to Python floats. + """Save decoded tuples as JSON, converting NumPy floats to Python floats. - We set ensure_ascii=True to prevent json from escaping non-ASCII chars - in the texts. - ''' - with open_remote(output_path, 'w') as fout: + We set ensure_ascii=True to prevent json from escaping non-ASCII chars + in the texts. + """ + with open_remote(output_path, "w") as fout: json.dump(samples, fout, default=float, ensure_ascii=False, indent=2) diff --git a/training/coqui_stt_training/util/feeding.py b/training/coqui_stt_training/util/feeding.py index 30a2b2f4..3cb1edb0 100644 --- a/training/coqui_stt_training/util/feeding.py +++ b/training/coqui_stt_training/util/feeding.py @@ -5,117 +5,175 @@ from collections import Counter from functools import partial import numpy as np -import tensorflow as tf +import tensorflow as tf from tensorflow.python.ops import gen_audio_ops as contrib_audio +from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split +from .augmentations import apply_graph_augmentations, apply_sample_augmentations from .config import Config -from .text import text_to_char_array from .flags import FLAGS -from .augmentations import apply_sample_augmentations, apply_graph_augmentations -from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT +from .helpers import MEGABYTE, remember_exception from .sample_collections import samples_from_sources -from .helpers import remember_exception, MEGABYTE +from .text import text_to_char_array -def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phase=False, augmentations=None, sample_id=None): +def audio_to_features( + audio, + sample_rate, + transcript=None, + clock=0.0, + train_phase=False, + augmentations=None, + sample_id=None, +): if train_phase: # We need the lambdas to make TensorFlow happy. # pylint: disable=unnecessary-lambda - tf.cond(tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate), - lambda: tf.print('WARNING: sample rate of sample', sample_id, '(', sample_rate, ') ' - 'does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'), - lambda: tf.no_op(), - name='matching_sample_rate') + tf.cond( + tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate), + lambda: tf.print( + "WARNING: sample rate of sample", + sample_id, + "(", + sample_rate, + ") " + "does not match FLAGS.audio_sample_rate. This can lead to incorrect results.", + ), + lambda: tf.no_op(), + name="matching_sample_rate", + ) if train_phase and augmentations: - audio = apply_graph_augmentations('signal', audio, augmentations, transcript=transcript, clock=clock) + audio = apply_graph_augmentations( + "signal", audio, augmentations, transcript=transcript, clock=clock + ) - spectrogram = contrib_audio.audio_spectrogram(audio, - window_size=Config.audio_window_samples, - stride=Config.audio_step_samples, - magnitude_squared=True) + spectrogram = contrib_audio.audio_spectrogram( + audio, + window_size=Config.audio_window_samples, + stride=Config.audio_step_samples, + magnitude_squared=True, + ) if train_phase and augmentations: - spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, transcript=transcript, clock=clock) + spectrogram = apply_graph_augmentations( + "spectrogram", + spectrogram, + augmentations, + transcript=transcript, + clock=clock, + ) - features = contrib_audio.mfcc(spectrogram=spectrogram, - sample_rate=sample_rate, - dct_coefficient_count=Config.n_input, - upper_frequency_limit=FLAGS.audio_sample_rate / 2) + features = contrib_audio.mfcc( + spectrogram=spectrogram, + sample_rate=sample_rate, + dct_coefficient_count=Config.n_input, + upper_frequency_limit=FLAGS.audio_sample_rate / 2, + ) features = tf.reshape(features, [-1, Config.n_input]) if train_phase and augmentations: - features = apply_graph_augmentations('features', features, augmentations, transcript=transcript, clock=clock) + features = apply_graph_augmentations( + "features", features, augmentations, transcript=transcript, clock=clock + ) return features, tf.shape(input=features)[0] -def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentations=None): +def audiofile_to_features( + wav_filename, clock=0.0, train_phase=False, augmentations=None +): samples = tf.io.read_file(wav_filename) decoded = contrib_audio.decode_wav(samples, desired_channels=1) - return audio_to_features(decoded.audio, - decoded.sample_rate, - clock=clock, - train_phase=train_phase, - augmentations=augmentations, - sample_id=wav_filename) + return audio_to_features( + decoded.audio, + decoded.sample_rate, + clock=clock, + train_phase=train_phase, + augmentations=augmentations, + sample_id=wav_filename, + ) -def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None): +def entry_to_features( + sample_id, + audio, + sample_rate, + transcript, + clock, + train_phase=False, + augmentations=None, +): # https://bugs.python.org/issue32117 sparse_transcript = tf.SparseTensor(*transcript) - features, features_len = audio_to_features(audio, - sample_rate, - transcript=sparse_transcript, - clock=clock, - train_phase=train_phase, - augmentations=augmentations, - sample_id=sample_id) + features, features_len = audio_to_features( + audio, + sample_rate, + transcript=sparse_transcript, + clock=clock, + train_phase=train_phase, + augmentations=augmentations, + sample_id=sample_id, + ) return sample_id, features, features_len, sparse_transcript def to_sparse_tuple(sequence): r"""Creates a sparse representention of ``sequence``. - Returns a tuple with (indices, values, shape) + Returns a tuple with (indices, values, shape) """ - indices = np.asarray(list(zip([0]*len(sequence), range(len(sequence)))), dtype=np.int64) + indices = np.asarray( + list(zip([0] * len(sequence), range(len(sequence)))), dtype=np.int64 + ) shape = np.asarray([1, len(sequence)], dtype=np.int64) return indices, sequence, shape -def create_dataset(sources, - batch_size, - epochs=1, - augmentations=None, - cache_path=None, - train_phase=False, - reverse=False, - limit=0, - exception_box=None, - process_ahead=None, - buffering=1 * MEGABYTE): +def create_dataset( + sources, + batch_size, + epochs=1, + augmentations=None, + cache_path=None, + train_phase=False, + reverse=False, + limit=0, + exception_box=None, + process_ahead=None, + buffering=1 * MEGABYTE, +): epoch_counter = Counter() # survives restarts of the dataset and its generator def generate_values(): - epoch = epoch_counter['epoch'] + epoch = epoch_counter["epoch"] if train_phase: - epoch_counter['epoch'] += 1 - samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse) + epoch_counter["epoch"] += 1 + samples = samples_from_sources( + sources, buffering=buffering, labeled=True, reverse=reverse + ) num_samples = len(samples) if limit > 0: num_samples = min(limit, num_samples) - samples = apply_sample_augmentations(samples, - augmentations, - buffering=buffering, - process_ahead=2 * batch_size if process_ahead is None else process_ahead, - clock=epoch / epochs, - final_clock=(epoch + 1) / epochs) + samples = apply_sample_augmentations( + samples, + augmentations, + buffering=buffering, + process_ahead=2 * batch_size if process_ahead is None else process_ahead, + clock=epoch / epochs, + final_clock=(epoch + 1) / epochs, + ) for sample_index, sample in enumerate(samples): if sample_index >= num_samples: break - clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0 - transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id) + clock = ( + (epoch * num_samples + sample_index) / (epochs * num_samples) + if train_phase and epochs > 0 + else 0.0 + ) + transcript = text_to_char_array( + sample.transcript, Config.alphabet, context=sample.sample_id + ) transcript = to_sparse_tuple(transcript) yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript, clock @@ -128,31 +186,46 @@ def create_dataset(sources, def batch_fn(sample_ids, features, features_len, transcripts): features = tf.data.Dataset.zip((features, features_len)) - features = features.padded_batch(batch_size, padded_shapes=([None, Config.n_input], [])) + features = features.padded_batch( + batch_size, padded_shapes=([None, Config.n_input], []) + ) transcripts = transcripts.batch(batch_size).map(sparse_reshape) sample_ids = sample_ids.batch(batch_size) return tf.data.Dataset.zip((sample_ids, features, transcripts)) - process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations) + process_fn = partial( + entry_to_features, train_phase=train_phase, augmentations=augmentations + ) - dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), - output_types=(tf.string, tf.float32, tf.int32, - (tf.int64, tf.int32, tf.int64), tf.float64)) - .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)) + dataset = tf.data.Dataset.from_generator( + remember_exception(generate_values, exception_box), + output_types=( + tf.string, + tf.float32, + tf.int32, + (tf.int64, tf.int32, tf.int64), + tf.float64, + ), + ).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) if cache_path: dataset = dataset.cache(cache_path) - dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn) - .prefetch(len(Config.available_devices))) + dataset = ( + dataset.window(batch_size, drop_remainder=train_phase) + .flat_map(batch_fn) + .prefetch(len(Config.available_devices)) + ) return dataset -def split_audio_file(audio_path, - audio_format=DEFAULT_FORMAT, - batch_size=1, - aggressiveness=3, - outlier_duration_ms=10000, - outlier_batch_size=1, - exception_box=None): +def split_audio_file( + audio_path, + audio_format=DEFAULT_FORMAT, + batch_size=1, + aggressiveness=3, + outlier_duration_ms=10000, + outlier_batch_size=1, + exception_box=None, +): def generate_values(): frames = read_frames_from_file(audio_path) segments = vad_split(frames, aggressiveness=aggressiveness) @@ -166,17 +239,23 @@ def split_audio_file(audio_path, return time_start, time_end, features, features_len def create_batch_set(bs, criteria): - return (tf.data.Dataset - .from_generator(remember_exception(generate_values, exception_box), - output_types=(tf.int32, tf.int32, tf.float32)) - .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE) - .filter(criteria) - .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], []))) + return ( + tf.data.Dataset.from_generator( + remember_exception(generate_values, exception_box), + output_types=(tf.int32, tf.int32, tf.float32), + ) + .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE) + .filter(criteria) + .padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], [])) + ) - nds = create_batch_set(batch_size, - lambda start, end, f, fl: end - start <= int(outlier_duration_ms)) - ods = create_batch_set(outlier_batch_size, - lambda start, end, f, fl: end - start > int(outlier_duration_ms)) + nds = create_batch_set( + batch_size, lambda start, end, f, fl: end - start <= int(outlier_duration_ms) + ) + ods = create_batch_set( + outlier_batch_size, + lambda start, end, f, fl: end - start > int(outlier_duration_ms), + ) dataset = nds.concatenate(ods) dataset = dataset.prefetch(len(Config.available_devices)) return dataset diff --git a/training/coqui_stt_training/util/flags.py b/training/coqui_stt_training/util/flags.py index 179a987d..7fb84ffe 100644 --- a/training/coqui_stt_training/util/flags.py +++ b/training/coqui_stt_training/util/flags.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import os + import absl.flags FLAGS = absl.flags.FLAGS @@ -12,179 +13,448 @@ def create_flags(): f = absl.flags - f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.') - f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run.') - f.DEFINE_string('test_files', '', 'comma separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested.') - f.DEFINE_string('metrics_files', '', 'comma separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed.') + f.DEFINE_string( + "train_files", + "", + "comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.", + ) + f.DEFINE_string( + "dev_files", + "", + "comma separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run.", + ) + f.DEFINE_string( + "test_files", + "", + "comma separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested.", + ) + f.DEFINE_string( + "metrics_files", + "", + "comma separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed.", + ) - f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)') - f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.') - f.DEFINE_integer('cache_for_epochs', 0, 'after how many epochs the feature cache is invalidated again - 0 for "never"') + f.DEFINE_string( + "read_buffer", + "1MB", + "buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)", + ) + f.DEFINE_string( + "feature_cache", + "", + "cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.", + ) + f.DEFINE_integer( + "cache_for_epochs", + 0, + 'after how many epochs the feature cache is invalidated again - 0 for "never"', + ) - f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds') - f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds') - f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model') - f.DEFINE_boolean('normalize_sample_rate', True, 'normalize sample rate of all train_files to --audio_sample_rate') + f.DEFINE_integer( + "feature_win_len", 32, "feature extraction audio window length in milliseconds" + ) + f.DEFINE_integer( + "feature_win_step", 20, "feature extraction window step length in milliseconds" + ) + f.DEFINE_integer("audio_sample_rate", 16000, "sample rate value expected by model") + f.DEFINE_boolean( + "normalize_sample_rate", + True, + "normalize sample rate of all train_files to --audio_sample_rate", + ) # Data Augmentation # ================ - f.DEFINE_multi_string('augment', None, 'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"') + f.DEFINE_multi_string( + "augment", + None, + 'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"', + ) # Global Constants # ================ - f.DEFINE_integer('epochs', 75, 'how many epochs (complete runs through the train files) to train for') + f.DEFINE_integer( + "epochs", + 75, + "how many epochs (complete runs through the train files) to train for", + ) - f.DEFINE_float('dropout_rate', 0.05, 'dropout rate for feedforward layers') - f.DEFINE_float('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate') - f.DEFINE_float('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate') - f.DEFINE_float('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0') - f.DEFINE_float('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0') - f.DEFINE_float('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate') + f.DEFINE_float("dropout_rate", 0.05, "dropout rate for feedforward layers") + f.DEFINE_float( + "dropout_rate2", -1.0, "dropout rate for layer 2 - defaults to dropout_rate" + ) + f.DEFINE_float( + "dropout_rate3", -1.0, "dropout rate for layer 3 - defaults to dropout_rate" + ) + f.DEFINE_float("dropout_rate4", 0.0, "dropout rate for layer 4 - defaults to 0.0") + f.DEFINE_float("dropout_rate5", 0.0, "dropout rate for layer 5 - defaults to 0.0") + f.DEFINE_float( + "dropout_rate6", -1.0, "dropout rate for layer 6 - defaults to dropout_rate" + ) - f.DEFINE_float('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers') + f.DEFINE_float("relu_clip", 20.0, "ReLU clipping value for non-recurrent layers") # Adam optimizer(http://arxiv.org/abs/1412.6980) parameters - f.DEFINE_float('beta1', 0.9, 'beta 1 parameter of Adam optimizer') - f.DEFINE_float('beta2', 0.999, 'beta 2 parameter of Adam optimizer') - f.DEFINE_float('epsilon', 1e-8, 'epsilon parameter of Adam optimizer') - f.DEFINE_float('learning_rate', 0.001, 'learning rate of Adam optimizer') + f.DEFINE_float("beta1", 0.9, "beta 1 parameter of Adam optimizer") + f.DEFINE_float("beta2", 0.999, "beta 2 parameter of Adam optimizer") + f.DEFINE_float("epsilon", 1e-8, "epsilon parameter of Adam optimizer") + f.DEFINE_float("learning_rate", 0.001, "learning rate of Adam optimizer") # Batch sizes - f.DEFINE_integer('train_batch_size', 1, 'number of elements in a training batch') - f.DEFINE_integer('dev_batch_size', 1, 'number of elements in a validation batch') - f.DEFINE_integer('test_batch_size', 1, 'number of elements in a test batch') + f.DEFINE_integer("train_batch_size", 1, "number of elements in a training batch") + f.DEFINE_integer("dev_batch_size", 1, "number of elements in a validation batch") + f.DEFINE_integer("test_batch_size", 1, "number of elements in a test batch") - f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph') + f.DEFINE_integer( + "export_batch_size", 1, "number of elements per batch on the exported graph" + ) # Performance - f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') - f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') - f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory') - f.DEFINE_boolean('load_cudnn', False, 'Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.') - f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work') - f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.') + f.DEFINE_integer( + "inter_op_parallelism_threads", + 0, + "number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED", + ) + f.DEFINE_integer( + "intra_op_parallelism_threads", + 0, + "number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED", + ) + f.DEFINE_boolean( + "use_allow_growth", + False, + "use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory", + ) + f.DEFINE_boolean( + "load_cudnn", + False, + "Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.", + ) + f.DEFINE_boolean( + "train_cudnn", + False, + "use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work", + ) + f.DEFINE_boolean( + "automatic_mixed_precision", + False, + "whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.", + ) # Sample limits - f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit') - f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit') - f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit') + f.DEFINE_integer( + "limit_train", + 0, + "maximum number of elements to use from train set - 0 means no limit", + ) + f.DEFINE_integer( + "limit_dev", + 0, + "maximum number of elements to use from validation set - 0 means no limit", + ) + f.DEFINE_integer( + "limit_test", + 0, + "maximum number of elements to use from test set - 0 means no limit", + ) # Sample order - f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set') - f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set') - f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set') + f.DEFINE_boolean( + "reverse_train", False, "if to reverse sample order of the train set" + ) + f.DEFINE_boolean("reverse_dev", False, "if to reverse sample order of the dev set") + f.DEFINE_boolean( + "reverse_test", False, "if to reverse sample order of the test set" + ) # Checkpointing - f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') - f.DEFINE_string('load_checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') - f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') - f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds') - f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5') - f.DEFINE_string('load_train', 'auto', 'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.') - f.DEFINE_string('load_evaluate', 'auto', 'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.') + f.DEFINE_string( + "checkpoint_dir", + "", + 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification', + ) + f.DEFINE_string( + "load_checkpoint_dir", + "", + 'directory in which checkpoints are stored - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification', + ) + f.DEFINE_string( + "save_checkpoint_dir", + "", + 'directory to which checkpoints are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification', + ) + f.DEFINE_integer("checkpoint_secs", 600, "checkpoint saving interval in seconds") + f.DEFINE_integer( + "max_to_keep", 5, "number of checkpoint files to keep - default value is 5" + ) + f.DEFINE_string( + "load_train", + "auto", + 'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.', + ) + f.DEFINE_string( + "load_evaluate", + "auto", + 'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.', + ) # Transfer Learning - f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)') + f.DEFINE_integer( + "drop_source_layers", + 0, + "single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)", + ) # Exporting - f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported') - f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models') - f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine') - f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency') - f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') - f.DEFINE_string('export_file_name', 'output_graph', 'name for the exported model file name') - f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph') + f.DEFINE_string( + "export_dir", + "", + "directory in which exported models are stored - if omitted, the model won't get exported", + ) + f.DEFINE_boolean("remove_export", False, "whether to remove old exported models") + f.DEFINE_boolean("export_tflite", False, "export a graph ready for TF Lite engine") + f.DEFINE_integer( + "n_steps", + 16, + "how many timesteps to process at once by the export graph, higher values mean more latency", + ) + f.DEFINE_boolean( + "export_zip", False, "export a TFLite model and package with LM and info.json" + ) + f.DEFINE_string( + "export_file_name", "output_graph", "name for the exported model file name" + ) + f.DEFINE_integer( + "export_beam_width", 500, "default beam width to embed into exported graph" + ) # Model metadata - f.DEFINE_string('export_author_id', 'author', 'author of the exported model. GitHub user or organization name used to uniquely identify the author of this model') - f.DEFINE_string('export_model_name', 'model', 'name of the exported model. Must not contain forward slashes.') - f.DEFINE_string('export_model_version', '0.0.1', 'semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with Coqui STT versions') + f.DEFINE_string( + "export_author_id", + "author", + "author of the exported model. GitHub user or organization name used to uniquely identify the author of this model", + ) + f.DEFINE_string( + "export_model_name", + "model", + "name of the exported model. Must not contain forward slashes.", + ) + f.DEFINE_string( + "export_model_version", + "0.0.1", + "semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with Coqui STT versions", + ) def str_val_equals_help(name, val_desc): - f.DEFINE_string(name, '<{}>'.format(val_desc), val_desc) + f.DEFINE_string(name, "<{}>".format(val_desc), val_desc) - str_val_equals_help('export_contact_info', 'public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors') - str_val_equals_help('export_license', 'SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.') - str_val_equals_help('export_language', 'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".') - str_val_equals_help('export_min_stt_version', 'minimum Coqui STT version (inclusive) the exported model is compatible with') - str_val_equals_help('export_max_stt_version', 'maximum Coqui STT version (inclusive) the exported model is compatible with') - str_val_equals_help('export_description', 'Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.') + str_val_equals_help( + "export_contact_info", + "public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors", + ) + str_val_equals_help( + "export_license", + "SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.", + ) + str_val_equals_help( + "export_language", + 'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".', + ) + str_val_equals_help( + "export_min_stt_version", + "minimum Coqui STT version (inclusive) the exported model is compatible with", + ) + str_val_equals_help( + "export_max_stt_version", + "maximum Coqui STT version (inclusive) the exported model is compatible with", + ) + str_val_equals_help( + "export_description", + "Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.", + ) # Reporting - f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR') - f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.') + f.DEFINE_integer( + "log_level", + 1, + "log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR", + ) + f.DEFINE_boolean( + "show_progressbar", + True, + "Show progress for training, validation and testing processes. Log level should be > 0.", + ) - f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console') - f.DEFINE_integer('report_count', 5, 'number of phrases for each of best WER, median WER and worst WER to print out during a WER report') + f.DEFINE_boolean( + "log_placement", + False, + "whether to log device placement of the operators to the console", + ) + f.DEFINE_integer( + "report_count", + 5, + "number of phrases for each of best WER, median WER and worst WER to print out during a WER report", + ) - f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "stt/summaries" within user\'s data home specified by the XDG Base Directory Specification') + f.DEFINE_string( + "summary_dir", + "", + 'target directory for TensorBoard summaries - defaults to directory "stt/summaries" within user\'s data home specified by the XDG Base Directory Specification', + ) - f.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples generated during a test epoch') + f.DEFINE_string( + "test_output_file", + "", + "path to a file to save all src/decoded/distance/loss tuples generated during a test epoch", + ) # Geometry - f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers') - f.DEFINE_boolean('layer_norm', False, 'wether to use layer-normalization after each fully-connected layer (except the last one)') + f.DEFINE_integer("n_hidden", 2048, "layer width to use when initialising layers") + f.DEFINE_boolean( + "layer_norm", + False, + "wether to use layer-normalization after each fully-connected layer (except the last one)", + ) # Initialization - f.DEFINE_integer('random_seed', 4568, 'default random seed that is used to initialize variables') + f.DEFINE_integer( + "random_seed", 4568, "default random seed that is used to initialize variables" + ) # Early Stopping - f.DEFINE_boolean('early_stop', False, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') - f.DEFINE_integer('es_epochs', 25, 'Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') - f.DEFINE_float('es_min_delta', 0.05, 'Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau') + f.DEFINE_boolean( + "early_stop", + False, + "Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.", + ) + f.DEFINE_integer( + "es_epochs", + 25, + "Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point", + ) + f.DEFINE_float( + "es_min_delta", + 0.05, + "Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau", + ) # Reduce learning rate on plateau - f.DEFINE_boolean('reduce_lr_on_plateau', False, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') - f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping') - f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.') - f.DEFINE_boolean('force_initialize_learning_rate', False, 'Force re-initialization of learning rate which was previously reduced.') + f.DEFINE_boolean( + "reduce_lr_on_plateau", + False, + "Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.", + ) + f.DEFINE_integer( + "plateau_epochs", + 10, + "Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping", + ) + f.DEFINE_float( + "plateau_reduction", + 0.1, + "Multiplicative factor to apply to the current learning rate if a plateau has occurred.", + ) + f.DEFINE_boolean( + "force_initialize_learning_rate", + False, + "Force re-initialization of learning rate which was previously reduced.", + ) # Decoder - f.DEFINE_boolean('bytes_output_mode', False, 'enable Bytes Output Mode mode. When this is used the model outputs UTF-8 byte values directly rather than using an alphabet mapping. The --alphabet_config_path option will be ignored. See the training documentation for more details.') - f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.') - f.DEFINE_string('scorer_path', '', 'path to the external scorer file.') - f.DEFINE_alias('scorer', 'scorer_path') - f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') - f.DEFINE_float('lm_alpha', 0.931289039105002, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') - f.DEFINE_float('lm_beta', 1.1834137581510284, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') - f.DEFINE_float('cutoff_prob', 1.0, 'only consider characters until this probability mass is reached. 1.0 = disabled.') - f.DEFINE_integer('cutoff_top_n', 300, 'only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.') + f.DEFINE_boolean( + "bytes_output_mode", + False, + "enable Bytes Output Mode mode. When this is used the model outputs UTF-8 byte values directly rather than using an alphabet mapping. The --alphabet_config_path option will be ignored. See the training documentation for more details.", + ) + f.DEFINE_string( + "alphabet_config_path", + "data/alphabet.txt", + "path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.", + ) + f.DEFINE_string("scorer_path", "", "path to the external scorer file.") + f.DEFINE_alias("scorer", "scorer_path") + f.DEFINE_integer( + "beam_width", + 1024, + "beam width used in the CTC decoder when building candidate transcriptions", + ) + f.DEFINE_float( + "lm_alpha", + 0.931289039105002, + "the alpha hyperparameter of the CTC decoder. Language Model weight.", + ) + f.DEFINE_float( + "lm_beta", + 1.1834137581510284, + "the beta hyperparameter of the CTC decoder. Word insertion weight.", + ) + f.DEFINE_float( + "cutoff_prob", + 1.0, + "only consider characters until this probability mass is reached. 1.0 = disabled.", + ) + f.DEFINE_integer( + "cutoff_top_n", + 300, + "only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.", + ) # Inference mode - f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.') + f.DEFINE_string( + "one_shot_infer", + "", + "one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.", + ) # Optimizer mode - f.DEFINE_float('lm_alpha_max', 5, 'the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.') - f.DEFINE_float('lm_beta_max', 5, 'the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.') - f.DEFINE_integer('n_trials', 2400, 'the number of trials to run during hyperparameter optimization.') + f.DEFINE_float( + "lm_alpha_max", + 5, + "the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.", + ) + f.DEFINE_float( + "lm_beta_max", + 5, + "the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.", + ) + f.DEFINE_integer( + "n_trials", + 2400, + "the number of trials to run during hyperparameter optimization.", + ) # Register validators for paths which require a file to be specified - f.register_validator('alphabet_config_path', - os.path.isfile, - message='The file pointed to by --alphabet_config_path must exist and be readable.') + f.register_validator( + "alphabet_config_path", + os.path.isfile, + message="The file pointed to by --alphabet_config_path must exist and be readable.", + ) + + f.register_validator( + "one_shot_infer", + lambda value: not value or os.path.isfile(value), + message="The file pointed to by --one_shot_infer must exist and be readable.", + ) - f.register_validator('one_shot_infer', - lambda value: not value or os.path.isfile(value), - message='The file pointed to by --one_shot_infer must exist and be readable.') # sphinx-doc: training_ref_flags_end diff --git a/training/coqui_stt_training/util/gpu.py b/training/coqui_stt_training/util/gpu.py index be9ac99f..aeb61597 100755 --- a/training/coqui_stt_training/util/gpu.py +++ b/training/coqui_stt_training/util/gpu.py @@ -6,4 +6,4 @@ def get_available_gpus(config): Returns the number of GPUs available on this system. """ local_device_protos = device_lib.list_local_devices(session_config=config) - return [x.name for x in local_device_protos if x.device_type == 'GPU'] + return [x.name for x in local_device_protos if x.device_type == "GPU"] diff --git a/training/coqui_stt_training/util/helpers.py b/training/coqui_stt_training/util/helpers.py index 5f434480..71749a74 100644 --- a/training/coqui_stt_training/util/helpers.py +++ b/training/coqui_stt_training/util/helpers.py @@ -1,21 +1,21 @@ +import heapq import os +import random import sys import time -import heapq -import semver -import random - -from multiprocessing import Pool from collections import namedtuple +from multiprocessing import Pool + +import semver KILO = 1024 KILOBYTE = 1 * KILO MEGABYTE = KILO * KILOBYTE GIGABYTE = KILO * MEGABYTE TERABYTE = KILO * GIGABYTE -SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE} +SIZE_PREFIX_LOOKUP = {"k": KILOBYTE, "m": MEGABYTE, "g": GIGABYTE, "t": TERABYTE} -ValueRange = namedtuple('ValueRange', 'start end r') +ValueRange = namedtuple("ValueRange", "start end r") def parse_file_size(file_size): @@ -23,39 +23,49 @@ def parse_file_size(file_size): if len(file_size) == 0: return 0 n = int(keep_only_digits(file_size)) - if file_size[-1] == 'b': + if file_size[-1] == "b": file_size = file_size[:-1] e = file_size[-1] return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n def keep_only_digits(txt): - return ''.join(filter(str.isdigit, txt)) + return "".join(filter(str.isdigit, txt)) def secs_to_hours(secs): hours, remainder = divmod(secs, 3600) minutes, seconds = divmod(remainder, 60) - return '%d:%02d:%02d' % (hours, minutes, seconds) + return "%d:%02d:%02d" % (hours, minutes, seconds) def check_ctcdecoder_version(): - ds_version_s = open(os.path.join(os.path.dirname(__file__), '../VERSION')).read().strip() + ds_version_s = ( + open(os.path.join(os.path.dirname(__file__), "../VERSION")).read().strip() + ) try: # pylint: disable=import-outside-toplevel from coqui_stt_ctcdecoder import __version__ as decoder_version except ImportError as e: - if e.msg.find('__version__') > 0: - print("Coqui STT version ({ds_version}) requires CTC decoder to expose __version__. " - "Please upgrade the coqui_stt_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s)) + if e.msg.find("__version__") > 0: + print( + "Coqui STT version ({ds_version}) requires CTC decoder to expose __version__. " + "Please upgrade the coqui_stt_ctcdecoder package to version {ds_version}".format( + ds_version=ds_version_s + ) + ) sys.exit(1) raise e rv = semver.compare(ds_version_s, decoder_version) if rv != 0: - print("Coqui STT version ({}) and CTC decoder version ({}) do not match. " - "Please ensure matching versions are in use.".format(ds_version_s, decoder_version)) + print( + "Coqui STT version ({}) and CTC decoder version ({}) do not match. " + "Please ensure matching versions are in use.".format( + ds_version_s, decoder_version + ) + ) sys.exit(1) return rv @@ -65,6 +75,7 @@ class Interleaved: """Collection that lazily combines sorted collections in an interleaving fashion. During iteration the next smallest element from all the sorted collections is always picked. The collections must support iter() and len().""" + def __init__(self, *iterables, key=lambda obj: obj, reverse=False): self.iterables = iterables self.key = key @@ -83,6 +94,7 @@ class LenMap: Wrapper around python map() output object that preserves the original collection length by implementing __len__. """ + def __init__(self, fn, iterable): try: self.length = len(iterable) @@ -108,11 +120,21 @@ class LimitingPool: """Limits unbound ahead-processing of multiprocessing.Pool's imap method before items get consumed by the iteration caller. This prevents OOM issues in situations where items represent larger memory allocations.""" - def __init__(self, processes=None, initializer=None, initargs=None, process_ahead=None, sleeping_for=0.1): + + def __init__( + self, + processes=None, + initializer=None, + initargs=None, + process_ahead=None, + sleeping_for=0.1, + ): self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead self.sleeping_for = sleeping_for self.processed = 0 - self.pool = Pool(processes=processes, initializer=initializer, initargs=initargs) + self.pool = Pool( + processes=processes, initializer=initializer, initargs=initargs + ) def __enter__(self): return self @@ -139,6 +161,7 @@ class LimitingPool: class ExceptionBox: """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator. Used in conjunction with `remember_exception`.""" + def __init__(self): self.exception = None @@ -152,6 +175,7 @@ class ExceptionBox: def remember_exception(iterable, exception_box=None): """Wraps a TensorFlow dataset generator for catching its actual exceptions that would otherwise just interrupt iteration w/o bubbling up.""" + def do_iterate(): try: yield from iterable() @@ -159,6 +183,7 @@ def remember_exception(iterable, exception_box=None): return except Exception as ex: # pylint: disable = broad-except exception_box.exception = ex + return iterable if exception_box is None else do_iterate @@ -174,30 +199,34 @@ def get_value_range(value, target_type): Any "missing" values are filled so that ValueRange always includes [start,end,r]. """ if isinstance(value, str): - if '~' in value: - parts = value.split('~') + if "~" in value: + parts = value.split("~") if len(parts) != 2: - raise ValueError('Cannot parse value range') + raise ValueError("Cannot parse value range") value = parts[0] r = parts[1] else: - r = 0 # if no supplied, use 0 - parts = value.split(':') + r = 0 # if no supplied, use 0 + parts = value.split(":") if len(parts) == 1: - parts.append(parts[0]) # only one given, so double it + parts.append(parts[0]) # only one given, so double it if len(parts) != 2: - raise ValueError('Cannot parse value range') + raise ValueError("Cannot parse value range") return ValueRange(target_type(parts[0]), target_type(parts[1]), target_type(r)) if isinstance(value, tuple): if len(value) == 2: - return ValueRange(target_type(value[0]), target_type(value[1]), target_type(0)) + return ValueRange( + target_type(value[0]), target_type(value[1]), target_type(0) + ) if len(value) == 3: - return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2])) + return ValueRange( + target_type(value[0]), target_type(value[1]), target_type(value[2]) + ) else: - raise ValueError('Cannot convert to ValueRange: Wrong tuple size') + raise ValueError("Cannot convert to ValueRange: Wrong tuple size") if isinstance(value, int) or isinstance(value, float): return ValueRange(target_type(value), target_type(value), target_type(0)) - raise ValueError('Cannot convert to ValueRange: Wrong tuple size') + raise ValueError("Cannot convert to ValueRange: Wrong tuple size") def int_range(value): @@ -217,20 +246,25 @@ def pick_value_from_range(value_range, clock=None): def tf_pick_value_from_range(value_range, clock=None, double_precision=False): import tensorflow as tf # pylint: disable=import-outside-toplevel + if clock is None: clock = tf.random.stateless_uniform([], seed=(-1, 1), dtype=tf.float64) else: - clock = tf.maximum(tf.constant(0.0, dtype=tf.float64), - tf.minimum(tf.constant(1.0, dtype=tf.float64), clock)) + clock = tf.maximum( + tf.constant(0.0, dtype=tf.float64), + tf.minimum(tf.constant(1.0, dtype=tf.float64), clock), + ) value = value_range.start + clock * (value_range.end - value_range.start) if value_range.r: # if the option (~, randomization radius) is supplied, # sample the value from a uniform distribution with "radius" - value = tf.random.stateless_uniform([], - minval=value - value_range.r, - maxval=value + value_range.r, - seed=(clock * tf.int32.min, clock * tf.int32.max), - dtype=tf.float64) + value = tf.random.stateless_uniform( + [], + minval=value - value_range.r, + maxval=value + value_range.r, + seed=(clock * tf.int32.min, clock * tf.int32.max), + dtype=tf.float64, + ) if isinstance(value_range.start, int): return tf.cast(tf.math.round(value), tf.int64 if double_precision else tf.int32) return tf.cast(value, tf.float64 if double_precision else tf.float32) diff --git a/training/coqui_stt_training/util/importers.py b/training/coqui_stt_training/util/importers.py index 61f2342d..caf373db 100644 --- a/training/coqui_stt_training/util/importers.py +++ b/training/coqui_stt_training/util/importers.py @@ -3,33 +3,72 @@ import importlib import os import re import sys - -from .helpers import secs_to_hours from collections import Counter +from .helpers import secs_to_hours + + def get_counter(): - return Counter({'all': 0, 'failed': 0, 'invalid_label': 0, 'too_short': 0, 'too_long': 0, 'imported_time': 0, 'total_time': 0}) + return Counter( + { + "all": 0, + "failed": 0, + "invalid_label": 0, + "too_short": 0, + "too_long": 0, + "imported_time": 0, + "total_time": 0, + } + ) + def get_imported_samples(counter): - return counter['all'] - counter['failed'] - counter['too_short'] - counter['too_long'] - counter['invalid_label'] + return ( + counter["all"] + - counter["failed"] + - counter["too_short"] + - counter["too_long"] + - counter["invalid_label"] + ) + def print_import_report(counter, sample_rate, max_secs): - print('Imported %d samples.' % (get_imported_samples(counter))) - if counter['failed'] > 0: - print('Skipped %d samples that failed upon conversion.' % counter['failed']) - if counter['invalid_label'] > 0: - print('Skipped %d samples that failed on transcript validation.' % counter['invalid_label']) - if counter['too_short'] > 0: - print('Skipped %d samples that were too short to match the transcript.' % counter['too_short']) - if counter['too_long'] > 0: - print('Skipped %d samples that were longer than %d seconds.' % (counter['too_long'], max_secs)) - print('Final amount of imported audio: %s from %s.' % (secs_to_hours(counter['imported_time'] / sample_rate), secs_to_hours(counter['total_time'] / sample_rate))) + print("Imported %d samples." % (get_imported_samples(counter))) + if counter["failed"] > 0: + print("Skipped %d samples that failed upon conversion." % counter["failed"]) + if counter["invalid_label"] > 0: + print( + "Skipped %d samples that failed on transcript validation." + % counter["invalid_label"] + ) + if counter["too_short"] > 0: + print( + "Skipped %d samples that were too short to match the transcript." + % counter["too_short"] + ) + if counter["too_long"] > 0: + print( + "Skipped %d samples that were longer than %d seconds." + % (counter["too_long"], max_secs) + ) + print( + "Final amount of imported audio: %s from %s." + % ( + secs_to_hours(counter["imported_time"] / sample_rate), + secs_to_hours(counter["total_time"] / sample_rate), + ) + ) + def get_importers_parser(description): parser = argparse.ArgumentParser(description=description) - parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.') + parser.add_argument( + "--validate_label_locale", + help="Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE's DIRECTORY INTO PYTHONPATH.", + ) return parser + def get_validate_label(args): """ Expects an argparse.Namespace argument to search for validate_label_locale parameter. @@ -43,19 +82,22 @@ def get_validate_label(args): :type: function """ # Python 3.5 does not support passing a pathlib.Path to os.path.* methods - if 'validate_label_locale' not in args or (args.validate_label_locale is None): - print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.') + if "validate_label_locale" not in args or (args.validate_label_locale is None): + print( + "WARNING: No --validate_label_locale specified, your might end with inconsistent dataset." + ) return validate_label_eng validate_label_locale = str(args.validate_label_locale) if not os.path.exists(os.path.abspath(validate_label_locale)): - print('ERROR: Inexistent --validate_label_locale specified. Please check.') + print("ERROR: Inexistent --validate_label_locale specified. Please check.") return None module_dir = os.path.abspath(os.path.dirname(validate_label_locale)) sys.path.insert(1, module_dir) - fname = os.path.basename(validate_label_locale).replace('.py', '') + fname = os.path.basename(validate_label_locale).replace(".py", "") locale_module = importlib.import_module(fname, package=None) return locale_module.validate_label + # Validate and normalize transcriptions. Returns a cleaned version of the label # or None if it's invalid. def validate_label_eng(label): @@ -72,7 +114,7 @@ def validate_label_eng(label): label = label.replace("?", "") label = label.replace("!", "") label = label.replace(":", "") - label = label.replace("\"", "") + label = label.replace('"', "") label = label.strip() label = label.lower() diff --git a/training/coqui_stt_training/util/io.py b/training/coqui_stt_training/util/io.py index bbd1e19d..a3fb3368 100644 --- a/training/coqui_stt_training/util/io.py +++ b/training/coqui_stt_training/util/io.py @@ -4,6 +4,7 @@ into HDFS storage using Tensorflow's C++ FileStream API. Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets. """ import os + from tensorflow.io import gfile @@ -12,7 +13,7 @@ def is_remote_path(path): Returns True iff the path is one of the remote formats that this module supports """ - return path.startswith('gs://') or path.startswith('hdfs://') + return path.startswith("gs://") or path.startswith("hdfs://") def path_exists_remote(path): @@ -32,7 +33,9 @@ def copy_remote(src, dst, overwrite=False): return gfile.copy(src, dst, overwrite) -def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None): +def open_remote( + path, mode="r", buffering=-1, encoding=None, newline=None, closefd=True, opener=None +): """ Wrapper around open() method that can handle remote paths like `gs://...` off Google Cloud using Tensorflow's IO helpers. @@ -45,7 +48,15 @@ def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, close """ if is_remote_path(path): return gfile.GFile(path, mode=mode) - return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener) + return open( + path, + mode, + buffering=buffering, + encoding=encoding, + newline=newline, + closefd=closefd, + opener=opener, + ) def isdir_remote(path): diff --git a/training/coqui_stt_training/util/logging.py b/training/coqui_stt_training/util/logging.py index a9f2d3d6..a416bd98 100644 --- a/training/coqui_stt_training/util/logging.py +++ b/training/coqui_stt_training/util/logging.py @@ -1,42 +1,43 @@ from __future__ import print_function -import progressbar import sys -from .flags import FLAGS +import progressbar +from .flags import FLAGS # Logging functions # ================= + def prefix_print(prefix, message): - print(prefix + ('\n' + prefix).join(message.split('\n'))) + print(prefix + ("\n" + prefix).join(message.split("\n"))) def log_debug(message): if FLAGS.log_level == 0: - prefix_print('D ', message) + prefix_print("D ", message) def log_info(message): if FLAGS.log_level <= 1: - prefix_print('I ', message) + prefix_print("I ", message) def log_warn(message): if FLAGS.log_level <= 2: - prefix_print('W ', message) + prefix_print("W ", message) def log_error(message): if FLAGS.log_level <= 3: - prefix_print('E ', message) + prefix_print("E ", message) def create_progressbar(*args, **kwargs): # Progress bars in stdout by default - if 'fd' not in kwargs: - kwargs['fd'] = sys.stdout + if "fd" not in kwargs: + kwargs["fd"] = sys.stdout if FLAGS.show_progressbar: return progressbar.ProgressBar(*args, **kwargs) diff --git a/training/coqui_stt_training/util/sample_collections.py b/training/coqui_stt_training/util/sample_collections.py index 5a467d50..23acf70f 100644 --- a/training/coqui_stt_training/util/sample_collections.py +++ b/training/coqui_stt_training/util/sample_collections.py @@ -1,45 +1,47 @@ # -*- coding: utf-8 -*- -import os -import io import csv +import io import json +import os import tarfile - -from pathlib import Path from functools import partial +from pathlib import Path -from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap from .audio import ( - Sample, - AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS, + AUDIO_TYPE_PCM, SERIALIZABLE_AUDIO_TYPES, + Sample, get_loadable_audio_type_from_extension, - write_wav + write_wav, ) -from .io import open_remote, is_remote_path +from .helpers import GIGABYTE, KILOBYTE, MEGABYTE, Interleaved, LenMap +from .io import is_remote_path, open_remote -BIG_ENDIAN = 'big' +BIG_ENDIAN = "big" INT_SIZE = 4 BIGINT_SIZE = 2 * INT_SIZE -MAGIC = b'SAMPLEDB' +MAGIC = b"SAMPLEDB" BUFFER_SIZE = 1 * MEGABYTE REVERSE_BUFFER_SIZE = 16 * KILOBYTE CACHE_SIZE = 1 * GIGABYTE -SCHEMA_KEY = 'schema' -CONTENT_KEY = 'content' -MIME_TYPE_KEY = 'mime-type' -MIME_TYPE_TEXT = 'text/plain' -CONTENT_TYPE_SPEECH = 'speech' -CONTENT_TYPE_TRANSCRIPT = 'transcript' +SCHEMA_KEY = "schema" +CONTENT_KEY = "content" +MIME_TYPE_KEY = "mime-type" +MIME_TYPE_TEXT = "text/plain" +CONTENT_TYPE_SPEECH = "speech" +CONTENT_TYPE_TRANSCRIPT = "transcript" class LabeledSample(Sample): """In-memory labeled audio sample representing an utterance. Derived from util.audio.Sample and used by sample collection readers and writers.""" - def __init__(self, audio_type, raw_data, transcript, audio_format=None, sample_id=None): + + def __init__( + self, audio_type, raw_data, transcript, audio_format=None, sample_id=None + ): """ Parameters ---------- @@ -55,7 +57,9 @@ class LabeledSample(Sample): Tracking ID - should indicate sample's origin as precisely as possible. It is typically assigned by collection readers. """ - super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id) + super().__init__( + audio_type, raw_data, audio_format=audio_format, sample_id=sample_id + ) self.transcript = transcript @@ -65,13 +69,14 @@ class PackedSample: have the child process do the loading/unpacking of the sample, allowing for parallel file I/O. """ + def __init__(self, filename, audio_type, label): self.filename = filename self.audio_type = audio_type self.label = label def unpack(self): - with open_remote(self.filename, 'rb') as audio_file: + with open_remote(self.filename, "rb") as audio_file: data = audio_file.read() if self.label is None: s = Sample(self.audio_type, data, sample_id=self.filename) @@ -83,7 +88,7 @@ def unpack_maybe(sample): """ Loads the supplied sample from disk (or the network) if the audio isn't loaded in to memory already. """ - if hasattr(sample, 'unpack'): + if hasattr(sample, "unpack"): realized_sample = sample.unpack() else: realized_sample = sample @@ -117,13 +122,16 @@ def load_sample(filename, label=None): class DirectSDBWriter: """Sample collection writer for creating a Sample DB (SDB) file""" - def __init__(self, - sdb_filename, - buffering=BUFFER_SIZE, - audio_type=AUDIO_TYPE_OPUS, - bitrate=None, - id_prefix=None, - labeled=True): + + def __init__( + self, + sdb_filename, + buffering=BUFFER_SIZE, + audio_type=AUDIO_TYPE_OPUS, + bitrate=None, + id_prefix=None, + labeled=True, + ): """ Parameters ---------- @@ -148,7 +156,7 @@ class DirectSDBWriter: raise ValueError('Audio type "{}" not supported'.format(audio_type)) self.audio_type = audio_type self.bitrate = bitrate - self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering) + self.sdb_file = open_remote(sdb_filename, "wb", buffering=buffering) self.offsets = [] self.num_samples = 0 @@ -156,7 +164,9 @@ class DirectSDBWriter: schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}] if self.labeled: - schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}) + schema_entries.append( + {CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT} + ) meta_data = {SCHEMA_KEY: schema_entries} meta_data = json.dumps(meta_data).encode() self.write_big_int(len(meta_data)) @@ -177,20 +187,23 @@ class DirectSDBWriter: def add(self, sample): def to_bytes(n): return n.to_bytes(INT_SIZE, BIG_ENDIAN) + sample.change_audio_type(self.audio_type, bitrate=self.bitrate) opus = sample.audio.getbuffer() opus_len = to_bytes(len(opus)) if self.labeled: transcript = sample.transcript.encode() transcript_len = to_bytes(len(transcript)) - entry_len = to_bytes(len(opus_len) + len(opus) + len(transcript_len) + len(transcript)) - buffer = b''.join([entry_len, opus_len, opus, transcript_len, transcript]) + entry_len = to_bytes( + len(opus_len) + len(opus) + len(transcript_len) + len(transcript) + ) + buffer = b"".join([entry_len, opus_len, opus, transcript_len, transcript]) else: entry_len = to_bytes(len(opus_len) + len(opus)) - buffer = b''.join([entry_len, opus_len, opus]) + buffer = b"".join([entry_len, opus_len, opus]) self.offsets.append(self.sdb_file.tell()) self.sdb_file.write(buffer) - sample.sample_id = '{}:{}'.format(self.id_prefix, self.num_samples) + sample.sample_id = "{}:{}".format(self.id_prefix, self.num_samples) self.num_samples += 1 return sample.sample_id @@ -221,12 +234,15 @@ class DirectSDBWriter: class SDB: # pylint: disable=too-many-instance-attributes """Sample collection reader for reading a Sample DB (SDB) file""" - def __init__(self, - sdb_filename, - buffering=BUFFER_SIZE, - id_prefix=None, - labeled=True, - reverse=False): + + def __init__( + self, + sdb_filename, + buffering=BUFFER_SIZE, + id_prefix=None, + labeled=True, + reverse=False, + ): """ Parameters ---------- @@ -244,30 +260,36 @@ class SDB: # pylint: disable=too-many-instance-attributes """ self.sdb_filename = sdb_filename self.id_prefix = sdb_filename if id_prefix is None else id_prefix - self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) + self.sdb_file = open_remote( + sdb_filename, "rb", buffering=REVERSE_BUFFER_SIZE if reverse else buffering + ) self.offsets = [] if self.sdb_file.read(len(MAGIC)) != MAGIC: - raise RuntimeError('No Sample Database') + raise RuntimeError("No Sample Database") meta_chunk_len = self.read_big_int() self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode()) if SCHEMA_KEY not in self.meta: - raise RuntimeError('Missing schema') + raise RuntimeError("Missing schema") self.schema = self.meta[SCHEMA_KEY] - speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES) + speech_columns = self.find_columns( + content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES + ) if not speech_columns: - raise RuntimeError('No speech data (missing in schema)') + raise RuntimeError("No speech data (missing in schema)") self.speech_index = speech_columns[0] self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY] self.transcript_index = None if labeled is not False: - transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT) + transcript_columns = self.find_columns( + content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT + ) if transcript_columns: self.transcript_index = transcript_columns[0] else: if labeled is True: - raise RuntimeError('No transcript data (missing in schema)') + raise RuntimeError("No transcript data (missing in schema)") sample_chunk_len = self.read_big_int() self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1) @@ -290,12 +312,16 @@ class SDB: # pylint: disable=too-many-instance-attributes if mime_type is not None: criteria.append((MIME_TYPE_KEY, mime_type)) if len(criteria) == 0: - raise ValueError('At least one of "content" or "mime-type" has to be provided') + raise ValueError( + 'At least one of "content" or "mime-type" has to be provided' + ) matches = [] for index, column in enumerate(self.schema): matched = 0 for field, value in criteria: - if column[field] == value or (isinstance(value, list) and column[field] in value): + if column[field] == value or ( + isinstance(value, list) and column[field] in value + ): matched += 1 if matched == len(criteria): matches.append(index) @@ -306,8 +332,11 @@ class SDB: # pylint: disable=too-many-instance-attributes column_data = [None] * len(columns) found = 0 if not 0 <= row_index < len(self.offsets): - raise ValueError('Wrong sample index: {} - has to be between 0 and {}' - .format(row_index, len(self.offsets) - 1)) + raise ValueError( + "Wrong sample index: {} - has to be between 0 and {}".format( + row_index, len(self.offsets) - 1 + ) + ) self.sdb_file.seek(self.offsets[row_index] + INT_SIZE) for index in range(len(self.schema)): chunk_len = self.read_int() @@ -321,13 +350,17 @@ class SDB: # pylint: disable=too-many-instance-attributes return tuple(column_data) def __getitem__(self, i): - sample_id = '{}:{}'.format(self.id_prefix, i) + sample_id = "{}:{}".format(self.id_prefix, i) if self.transcript_index is None: [audio_data] = self.read_row(i, self.speech_index) return Sample(self.audio_type, audio_data, sample_id=sample_id) - audio_data, transcript = self.read_row(i, self.speech_index, self.transcript_index) + audio_data, transcript = self.read_row( + i, self.speech_index, self.transcript_index + ) transcript = transcript.decode() - return LabeledSample(self.audio_type, audio_data, transcript, sample_id=sample_id) + return LabeledSample( + self.audio_type, audio_data, transcript, sample_id=sample_id + ) def __iter__(self): for i in range(len(self.offsets)): @@ -346,10 +379,8 @@ class SDB: # pylint: disable=too-many-instance-attributes class CSVWriter: # pylint: disable=too-many-instance-attributes """Sample collection writer for writing a CSV data-set and all its referenced WAV samples""" - def __init__(self, - csv_filename, - absolute_paths=False, - labeled=True): + + def __init__(self, csv_filename, absolute_paths=False, labeled=True): """ Parameters ---------- @@ -361,7 +392,7 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes labeled : bool or None If True: Writes labeled samples (util.sample_collections.LabeledSample) only. If False: Ignores transcripts (if available) and writes (unlabeled) util.audio.Sample instances. - + Currently only works with local files (not gs:// or hdfs://...) """ self.csv_filename = Path(csv_filename) @@ -372,11 +403,11 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes raise RuntimeError('"{}" already existing'.format(self.csv_dir)) os.mkdir(str(self.csv_dir)) self.absolute_paths = absolute_paths - fieldnames = ['wav_filename', 'wav_filesize'] + fieldnames = ["wav_filename", "wav_filesize"] self.labeled = labeled if labeled: - fieldnames.append('transcript') - self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='') + fieldnames.append("transcript") + self.csv_file = open_remote(csv_filename, "w", encoding="utf-8", newline="") self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) self.csv_writer.writeheader() self.counter = 0 @@ -385,17 +416,19 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes return self def add(self, sample): - sample_filename = self.csv_dir / 'sample{0:08d}.wav'.format(self.counter) + sample_filename = self.csv_dir / "sample{0:08d}.wav".format(self.counter) self.counter += 1 sample.change_audio_type(AUDIO_TYPE_PCM) write_wav(str(sample_filename), sample.audio, audio_format=sample.audio_format) sample.sample_id = str(sample_filename.relative_to(self.csv_base_dir)) row = { - 'wav_filename': str(sample_filename.absolute()) if self.absolute_paths else sample.sample_id, - 'wav_filesize': sample_filename.stat().st_size + "wav_filename": str(sample_filename.absolute()) + if self.absolute_paths + else sample.sample_id, + "wav_filesize": sample_filename.stat().st_size, } if self.labeled: - row['transcript'] = sample.transcript + row["transcript"] = sample.transcript self.csv_writer.writerow(row) return sample.sample_id @@ -412,11 +445,8 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes class TarWriter: # pylint: disable=too-many-instance-attributes """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.""" - def __init__(self, - tar_filename, - gz=False, - labeled=True, - include=None): + + def __init__(self, tar_filename, gz=False, labeled=True, include=None): """ Parameters ---------- @@ -432,17 +462,19 @@ class TarWriter: # pylint: disable=too-many-instance-attributes Currently only works with local files (not gs:// or hdfs://...) """ - self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w') - samples_dir = tarfile.TarInfo('samples') + self.tar = tarfile.open(tar_filename, "w:gz" if gz else "w") + samples_dir = tarfile.TarInfo("samples") samples_dir.type = tarfile.DIRTYPE self.tar.addfile(samples_dir) if include: for include_path in include: - self.tar.add(include_path, recursive=False, arcname=Path(include_path).name) - fieldnames = ['wav_filename', 'wav_filesize'] + self.tar.add( + include_path, recursive=False, arcname=Path(include_path).name + ) + fieldnames = ["wav_filename", "wav_filesize"] self.labeled = labeled if labeled: - fieldnames.append('transcript') + fieldnames.append("transcript") self.csv_file = io.StringIO() self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) self.csv_writer.writeheader() @@ -452,7 +484,7 @@ class TarWriter: # pylint: disable=too-many-instance-attributes return self def add(self, sample): - sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter) + sample_filename = "samples/sample{0:08d}.wav".format(self.counter) self.counter += 1 sample.change_audio_type(AUDIO_TYPE_PCM) sample_file = io.BytesIO() @@ -462,21 +494,18 @@ class TarWriter: # pylint: disable=too-many-instance-attributes sample_tar = tarfile.TarInfo(sample_filename) sample_tar.size = sample_size self.tar.addfile(sample_tar, sample_file) - row = { - 'wav_filename': sample_filename, - 'wav_filesize': sample_size - } + row = {"wav_filename": sample_filename, "wav_filesize": sample_size} if self.labeled: - row['transcript'] = sample.transcript + row["transcript"] = sample.transcript self.csv_writer.writerow(row) return sample_filename def close(self): if self.csv_file and self.tar: - csv_tar = tarfile.TarInfo('samples.csv') + csv_tar = tarfile.TarInfo("samples.csv") csv_tar.size = self.csv_file.tell() self.csv_file.seek(0) - self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode('utf8'))) + self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode("utf8"))) if self.tar: self.tar.close() @@ -489,6 +518,7 @@ class TarWriter: # pylint: disable=too-many-instance-attributes class SampleList: """Sample collection base class with samples loaded from a list of in-memory paths.""" + def __init__(self, samples, labeled=True, reverse=False): """ Parameters @@ -507,7 +537,9 @@ class SampleList: def __getitem__(self, i): sample_spec = self.samples[i] - return load_sample(sample_spec[0], label=sample_spec[2] if self.labeled else None) + return load_sample( + sample_spec[0], label=sample_spec[2] if self.labeled else None + ) def __len__(self): return len(self.samples) @@ -516,6 +548,7 @@ class SampleList: class CSV(SampleList): """Sample collection reader for reading a Coqui STT CSV file Automatically orders samples by CSV column wav_filesize (if available).""" + def __init__(self, csv_filename, labeled=None, reverse=False): """ Parameters @@ -531,30 +564,34 @@ class CSV(SampleList): If the order of the samples should be reversed """ rows = [] - with open_remote(csv_filename, 'r', encoding='utf8') as csv_file: + with open_remote(csv_filename, "r", encoding="utf8") as csv_file: reader = csv.DictReader(csv_file) - if 'transcript' in reader.fieldnames: + if "transcript" in reader.fieldnames: if labeled is None: labeled = True elif labeled: - raise RuntimeError('No transcript data (missing CSV column)') + raise RuntimeError("No transcript data (missing CSV column)") for row in reader: - wav_filename = Path(row['wav_filename']) - if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']): + wav_filename = Path(row["wav_filename"]) + if not wav_filename.is_absolute() and not is_remote_path( + row["wav_filename"] + ): wav_filename = Path(csv_filename).parent / wav_filename wav_filename = str(wav_filename) else: # Pathlib otherwise removes a / from filenames like hdfs:// - wav_filename = row['wav_filename'] - wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0 + wav_filename = row["wav_filename"] + wav_filesize = int(row["wav_filesize"]) if "wav_filesize" in row else 0 if labeled: - rows.append((wav_filename, wav_filesize, row['transcript'])) + rows.append((wav_filename, wav_filesize, row["transcript"])) else: rows.append((wav_filename, wav_filesize)) super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse) -def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False): +def samples_from_source( + sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False +): """ Loads samples from a sample source file. @@ -577,14 +614,16 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reve iterable of util.sample_collections.LabeledSample or util.audio.Sample instances supporting len. """ ext = os.path.splitext(sample_source)[1].lower() - if ext == '.sdb': + if ext == ".sdb": return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse) - if ext == '.csv': + if ext == ".csv": return CSV(sample_source, labeled=labeled, reverse=reverse) raise ValueError('Unknown file type: "{}"'.format(ext)) -def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False): +def samples_from_sources( + sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False +): """ Loads and combines samples from a list of source files. Sources are combined in an interleaving way to keep default sample order from shortest to longest. @@ -616,14 +655,22 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re """ sample_sources = list(sample_sources) if len(sample_sources) == 0: - raise ValueError('No files') + raise ValueError("No files") if len(sample_sources) == 1: - return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse) + return samples_from_source( + sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse + ) # If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should # be done lazily onn the fly so that it respects the LimitingPool logic used in the feeding code. - cols = [LenMap( - unpack_maybe, samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)) - for source in sample_sources] + cols = [ + LenMap( + unpack_maybe, + samples_from_source( + source, buffering=buffering, labeled=labeled, reverse=reverse + ), + ) + for source in sample_sources + ] return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse) diff --git a/training/coqui_stt_training/util/stm.py b/training/coqui_stt_training/util/stm.py index 23383341..ee942a12 100644 --- a/training/coqui_stt_training/util/stm.py +++ b/training/coqui_stt_training/util/stm.py @@ -1,27 +1,31 @@ import codecs import unicodedata + class STMSegment(object): r""" Representation of an individual segment in an STM file. """ + def __init__(self, stm_line): tokens = stm_line.split() - self._filename = tokens[0] - self._channel = tokens[1] - self._speaker_id = tokens[2] - self._start_time = float(tokens[3]) - self._stop_time = float(tokens[4]) - self._labels = tokens[5] - self._transcript = "" + self._filename = tokens[0] + self._channel = tokens[1] + self._speaker_id = tokens[2] + self._start_time = float(tokens[3]) + self._stop_time = float(tokens[4]) + self._labels = tokens[5] + self._transcript = "" for token in tokens[6:]: - self._transcript += token + " " + self._transcript += token + " " # We need to do the encode-decode dance here because encode # returns a bytes() object on Python 3, and text_to_char_array # expects a string. - self._transcript = unicodedata.normalize("NFKD", self._transcript.strip()) \ - .encode("ascii", "ignore") \ - .decode("ascii", "ignore") + self._transcript = ( + unicodedata.normalize("NFKD", self._transcript.strip()) + .encode("ascii", "ignore") + .decode("ascii", "ignore") + ) @property def filename(self): @@ -51,6 +55,7 @@ class STMSegment(object): def transcript(self): return self._transcript + def parse_stm_file(stm_file): r""" Parses an STM file at ``stm_file`` into a list of :class:`STMSegment`. diff --git a/training/coqui_stt_training/util/text.py b/training/coqui_stt_training/util/text.py index 198bd96e..2343882a 100644 --- a/training/coqui_stt_training/util/text.py +++ b/training/coqui_stt_training/util/text.py @@ -1,9 +1,11 @@ from __future__ import absolute_import, division, print_function -import numpy as np import struct -def text_to_char_array(transcript, alphabet, context=''): +import numpy as np + + +def text_to_char_array(transcript, alphabet, context=""): r""" Given a transcript string, map characters to integers and return a numpy array representing the processed string. @@ -13,15 +15,20 @@ def text_to_char_array(transcript, alphabet, context=''): # Provide the row context (especially wav_filename) for alphabet errors raise ValueError( 'Alphabet cannot encode transcript "{}" while processing sample "{}", ' - 'check that your alphabet contains all characters in the training corpus. ' - 'Missing characters are: {}.' - .format(transcript, context, list(ch for ch in transcript if not alphabet.CanEncodeSingle(ch)))) + "check that your alphabet contains all characters in the training corpus. " + "Missing characters are: {}.".format( + transcript, + context, + list(ch for ch in transcript if not alphabet.CanEncodeSingle(ch)), + ) + ) encoded = alphabet.Encode(transcript) if len(encoded) == 0: - raise ValueError('While processing {}: Found an empty transcript! ' - 'You must include a transcript for all training data.' - .format(context)) + raise ValueError( + "While processing {}: Found an empty transcript! " + "You must include a transcript for all training data.".format(context) + ) return encoded @@ -35,6 +42,7 @@ def text_to_char_array(transcript, alphabet, context=''): # version 1.0. This software is distributed without any warranty. For more # information, see + def levenshtein(a, b): "Calculates the Levenshtein distance between a and b." n, m = len(a), len(b) @@ -43,13 +51,13 @@ def levenshtein(a, b): a, b = b, a n, m = m, n - current = list(range(n+1)) - for i in range(1, m+1): - previous, current = current, [i]+[0]*n - for j in range(1, n+1): - add, delete = previous[j]+1, current[j-1]+1 - change = previous[j-1] - if a[j-1] != b[i-1]: + current = list(range(n + 1)) + for i in range(1, m + 1): + previous, current = current, [i] + [0] * n + for j in range(1, n + 1): + add, delete = previous[j] + 1, current[j - 1] + 1 + change = previous[j - 1] + if a[j - 1] != b[i - 1]: change = change + 1 current[j] = min(add, delete, change) diff --git a/transcribe.py b/transcribe.py index c6d79dab..6ca6d441 100755 --- a/transcribe.py +++ b/transcribe.py @@ -2,24 +2,32 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function +import json import os import sys -import json -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import tensorflow as tf import tensorflow.compat.v1.logging as tflogging + tflogging.set_verbosity(tflogging.ERROR) import logging -logging.getLogger('sox').setLevel(logging.ERROR) -import glob +logging.getLogger("sox").setLevel(logging.ERROR) +import glob +from multiprocessing import Process, cpu_count + +from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch from coqui_stt_training.util.audio import AudioFile from coqui_stt_training.util.config import Config, initialize_globals from coqui_stt_training.util.feeding import split_audio_file -from coqui_stt_training.util.flags import create_flags, FLAGS -from coqui_stt_training.util.logging import log_error, log_info, log_progress, create_progressbar -from coqui_stt_ctcdecoder import ctc_beam_search_decoder_batch, Scorer -from multiprocessing import Process, cpu_count +from coqui_stt_training.util.flags import FLAGS, create_flags +from coqui_stt_training.util.logging import ( + create_progressbar, + log_error, + log_info, + log_progress, +) def fail(message, code=1): @@ -28,8 +36,11 @@ def fail(message, code=1): def transcribe_file(audio_path, tlog_path): - from coqui_stt_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel + from coqui_stt_training.train import ( # pylint: disable=cyclic-import,import-outside-toplevel + create_model, + ) from coqui_stt_training.util.checkpoints import load_graph_for_evaluation + initialize_globals() scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) try: @@ -37,16 +48,23 @@ def transcribe_file(audio_path, tlog_path): except NotImplementedError: num_processes = 1 with AudioFile(audio_path, as_path=True) as wav_path: - data_set = split_audio_file(wav_path, - batch_size=FLAGS.batch_size, - aggressiveness=FLAGS.vad_aggressiveness, - outlier_duration_ms=FLAGS.outlier_duration_ms, - outlier_batch_size=FLAGS.outlier_batch_size) - iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes, - output_classes=data_set.output_classes) + data_set = split_audio_file( + wav_path, + batch_size=FLAGS.batch_size, + aggressiveness=FLAGS.vad_aggressiveness, + outlier_duration_ms=FLAGS.outlier_duration_ms, + outlier_batch_size=FLAGS.outlier_batch_size, + ) + iterator = tf.data.Iterator.from_structure( + data_set.output_types, + data_set.output_shapes, + output_classes=data_set.output_classes, + ) batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next() no_dropout = [None] * 6 - logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout) + logits, _ = create_model( + batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout + ) transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) tf.train.get_or_create_global_step() with tf.Session(config=Config.session_config) as session: @@ -55,30 +73,43 @@ def transcribe_file(audio_path, tlog_path): transcripts = [] while True: try: - starts, ends, batch_logits, batch_lengths = \ - session.run([batch_time_start, batch_time_end, transposed, batch_x_len]) + starts, ends, batch_logits, batch_lengths = session.run( + [batch_time_start, batch_time_end, transposed, batch_x_len] + ) except tf.errors.OutOfRangeError: break - decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width, - num_processes=num_processes, - scorer=scorer) + decoded = ctc_beam_search_decoder_batch( + batch_logits, + batch_lengths, + Config.alphabet, + FLAGS.beam_width, + num_processes=num_processes, + scorer=scorer, + ) decoded = list(d[0][1] for d in decoded) transcripts.extend(zip(starts, ends, decoded)) transcripts.sort(key=lambda t: t[0]) - transcripts = [{'start': int(start), - 'end': int(end), - 'transcript': transcript} for start, end, transcript in transcripts] - with open(tlog_path, 'w') as tlog_file: + transcripts = [ + {"start": int(start), "end": int(end), "transcript": transcript} + for start, end, transcript in transcripts + ] + with open(tlog_path, "w") as tlog_file: json.dump(transcripts, tlog_file, default=float) -def transcribe_many(src_paths,dst_paths): - pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(src_paths)).start() +def transcribe_many(src_paths, dst_paths): + pbar = create_progressbar( + prefix="Transcribing files | ", max_value=len(src_paths) + ).start() for i in range(len(src_paths)): p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i])) p.start() p.join() - log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(src_paths), src_paths[i], dst_paths[i])) + log_progress( + 'Transcribed file {} of {} from "{}" to "{}"'.format( + i + 1, len(src_paths), src_paths[i], dst_paths[i] + ) + ) pbar.update(i) pbar.finish() @@ -99,70 +130,116 @@ def resolve(base_path, spec_path): def main(_): if not FLAGS.src or not os.path.exists(FLAGS.src): # path not given or non-existant - fail('You have to specify which file or catalog to transcribe via the --src flag.') + fail( + "You have to specify which file or catalog to transcribe via the --src flag." + ) else: # path given and exists src_path = os.path.abspath(FLAGS.src) if os.path.isfile(src_path): - if src_path.endswith('.catalog'): + if src_path.endswith(".catalog"): # Transcribe batch of files via ".catalog" file (from DSAlign) if FLAGS.dst: - fail('Parameter --dst not supported if --src points to a catalog') + fail("Parameter --dst not supported if --src points to a catalog") catalog_dir = os.path.dirname(src_path) - with open(src_path, 'r') as catalog_file: + with open(src_path, "r") as catalog_file: catalog_entries = json.load(catalog_file) - catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries] + catalog_entries = [ + (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"])) + for e in catalog_entries + ] if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)): - fail('Missing source file(s) in catalog') - if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)): - fail('Destination file(s) from catalog already existing, use --force for overwriting') - if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)): - fail('Missing destination directory for at least one catalog entry') - src_paths,dst_paths = zip(*paths) - transcribe_many(src_paths,dst_paths) + fail("Missing source file(s) in catalog") + if not FLAGS.force and any( + map(lambda e: os.path.isfile(e[1]), catalog_entries) + ): + fail( + "Destination file(s) from catalog already existing, use --force for overwriting" + ) + if any( + map( + lambda e: not os.path.isdir(os.path.dirname(e[1])), + catalog_entries, + ) + ): + fail("Missing destination directory for at least one catalog entry") + src_paths, dst_paths = zip(*paths) + transcribe_many(src_paths, dst_paths) else: # Transcribe one file - dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog' + dst_path = ( + os.path.abspath(FLAGS.dst) + if FLAGS.dst + else os.path.splitext(src_path)[0] + ".tlog" + ) if os.path.isfile(dst_path): if FLAGS.force: transcribe_one(src_path, dst_path) else: - fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0) + fail( + 'Destination file "{}" already existing - use --force for overwriting'.format( + dst_path + ), + code=0, + ) elif os.path.isdir(os.path.dirname(dst_path)): transcribe_one(src_path, dst_path) else: - fail('Missing destination directory') + fail("Missing destination directory") elif os.path.isdir(src_path): # Transcribe all files in dir print("Transcribing all WAV files in --src") if FLAGS.dst: - fail('Destination file not supported for batch decoding jobs.') + fail("Destination file not supported for batch decoding jobs.") else: if not FLAGS.recursive: - print("If you wish to recursively scan --src, then you must use --recursive") + print( + "If you wish to recursively scan --src, then you must use --recursive" + ) wav_paths = glob.glob(src_path + "/*.wav") else: wav_paths = glob.glob(src_path + "/**/*.wav") - dst_paths = [path.replace('.wav','.tlog') for path in wav_paths] - transcribe_many(wav_paths,dst_paths) + dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths] + transcribe_many(wav_paths, dst_paths) -if __name__ == '__main__': +if __name__ == "__main__": create_flags() - tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.' - 'Catalog files should be formatted from DSAlign. A directory will' - 'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be ' - 'written in-place using the source filenames with ' - 'suffix ".tlog" instead of ".wav".') - tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). ' - 'If --src is a directory, this one also has to be a directory ' - 'and the required sub-dir tree of --src will get replicated.') - tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively') - tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing ' - 'transcription logs (.tlog)') - tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should ' - 'split audio') - tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size') - tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers') - tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)') + tf.app.flags.DEFINE_string( + "src", + "", + "Source path to an audio file or directory or catalog file." + "Catalog files should be formatted from DSAlign. A directory will" + "be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be " + "written in-place using the source filenames with " + 'suffix ".tlog" instead of ".wav".', + ) + tf.app.flags.DEFINE_string( + "dst", + "", + "path for writing the transcription log or logs (.tlog). " + "If --src is a directory, this one also has to be a directory " + "and the required sub-dir tree of --src will get replicated.", + ) + tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively") + tf.app.flags.DEFINE_boolean( + "force", + False, + "Forces re-transcribing and overwriting of already existing " + "transcription logs (.tlog)", + ) + tf.app.flags.DEFINE_integer( + "vad_aggressiveness", + 3, + "How aggressive (0=lowest, 3=highest) the VAD should " "split audio", + ) + tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size") + tf.app.flags.DEFINE_float( + "outlier_duration_ms", + 10000, + "Duration in ms after which samples are considered outliers", + ) + tf.app.flags.DEFINE_integer( + "outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)" + ) tf.app.run(main)