Run pre-commit hooks on all files

This commit is contained in:
Reuben Morais 2021-05-18 13:45:52 +02:00
parent 14aee5d35b
commit 43a6c3e62a
140 changed files with 4008 additions and 2214 deletions

View File

@ -22,7 +22,3 @@ repos:
- id: isort - id: isort
name: isort (pyi) name: isort (pyi)
types: [pyi] types: [pyi]
- repo: https://github.com/pycqa/pylint
rev: v2.8.2
hooks:
- id: pylint

View File

@ -67,4 +67,3 @@ Links & Resources
- `see the latest release on GitHub <https://github.com/coqui-ai/STT/releases/latest>`_ - `see the latest release on GitHub <https://github.com/coqui-ai/STT/releases/latest>`_
* - 🤝 **Contribution Guidelines** * - 🤝 **Contribution Guidelines**
- `CONTRIBUTING.rst <CONTRIBUTING.rst>`_ - `CONTRIBUTING.rst <CONTRIBUTING.rst>`_

View File

@ -2,10 +2,10 @@
""" """
Tool for comparing two wav samples Tool for comparing two wav samples
""" """
import sys
import argparse import argparse
import numpy as np import sys
import numpy as np
from coqui_stt_training.util.audio import AUDIO_TYPE_NP, mean_dbfs from coqui_stt_training.util.audio import AUDIO_TYPE_NP, mean_dbfs
from coqui_stt_training.util.sample_collections import load_sample from coqui_stt_training.util.sample_collections import load_sample
@ -19,19 +19,29 @@ def compare_samples():
sample1 = load_sample(CLI_ARGS.sample1).unpack() sample1 = load_sample(CLI_ARGS.sample1).unpack()
sample2 = load_sample(CLI_ARGS.sample2).unpack() sample2 = load_sample(CLI_ARGS.sample2).unpack()
if sample1.audio_format != sample2.audio_format: if sample1.audio_format != sample2.audio_format:
fail('Samples differ on: audio-format ({} and {})'.format(sample1.audio_format, sample2.audio_format)) fail(
"Samples differ on: audio-format ({} and {})".format(
sample1.audio_format, sample2.audio_format
)
)
if abs(sample1.duration - sample2.duration) > 0.001: if abs(sample1.duration - sample2.duration) > 0.001:
fail('Samples differ on: duration ({} and {})'.format(sample1.duration, sample2.duration)) fail(
"Samples differ on: duration ({} and {})".format(
sample1.duration, sample2.duration
)
)
sample1.change_audio_type(AUDIO_TYPE_NP) sample1.change_audio_type(AUDIO_TYPE_NP)
sample2.change_audio_type(AUDIO_TYPE_NP) sample2.change_audio_type(AUDIO_TYPE_NP)
samples = [sample1, sample2] samples = [sample1, sample2]
largest = np.argmax([sample1.audio.shape[0], sample2.audio.shape[0]]) largest = np.argmax([sample1.audio.shape[0], sample2.audio.shape[0]])
smallest = (largest + 1) % 2 smallest = (largest + 1) % 2
samples[largest].audio = samples[largest].audio[:len(samples[smallest].audio)] samples[largest].audio = samples[largest].audio[: len(samples[smallest].audio)]
audio_diff = samples[largest].audio - samples[smallest].audio audio_diff = samples[largest].audio - samples[smallest].audio
diff_dbfs = mean_dbfs(audio_diff) diff_dbfs = mean_dbfs(audio_diff)
differ_msg = 'Samples differ on: sample data ({:0.2f} dB difference) '.format(diff_dbfs) differ_msg = "Samples differ on: sample data ({:0.2f} dB difference) ".format(
equal_msg = 'Samples are considered equal ({:0.2f} dB difference)'.format(diff_dbfs) diff_dbfs
)
equal_msg = "Samples are considered equal ({:0.2f} dB difference)".format(diff_dbfs)
if CLI_ARGS.if_differ: if CLI_ARGS.if_differ:
if diff_dbfs <= CLI_ARGS.threshold: if diff_dbfs <= CLI_ARGS.threshold:
fail(equal_msg) fail(equal_msg)
@ -50,13 +60,17 @@ def handle_args():
) )
parser.add_argument("sample1", help="Filename of sample 1 to compare") parser.add_argument("sample1", help="Filename of sample 1 to compare")
parser.add_argument("sample2", help="Filename of sample 2 to compare") parser.add_argument("sample2", help="Filename of sample 2 to compare")
parser.add_argument("--threshold", type=float, default=-60.0, parser.add_argument(
help="dB of sample deltas above which they are considered different") "--threshold",
type=float,
default=-60.0,
help="dB of sample deltas above which they are considered different",
)
parser.add_argument( parser.add_argument(
"--if-differ", "--if-differ",
action="store_true", action="store_true",
help="If to succeed and return status code 0 on different signals and fail on equal ones (inverse check)." help="If to succeed and return status code 0 on different signals and fail on equal ones (inverse check)."
"This will still fail on different formats or durations.", "This will still fail on different formats or durations.",
) )
parser.add_argument( parser.add_argument(
"--no-success-output", "--no-success-output",

View File

@ -1,19 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
''' """
Tool for building a combined SDB or CSV sample-set from other sets Tool for building a combined SDB or CSV sample-set from other sets
Use 'python3 data_set_tool.py -h' for help Use 'python3 data_set_tool.py -h' for help
''' """
import sys
import argparse import argparse
import progressbar import sys
from pathlib import Path from pathlib import Path
import progressbar
from coqui_stt_training.util.audio import ( from coqui_stt_training.util.audio import (
AUDIO_TYPE_PCM,
AUDIO_TYPE_OPUS, AUDIO_TYPE_OPUS,
AUDIO_TYPE_PCM,
AUDIO_TYPE_WAV, AUDIO_TYPE_WAV,
change_audio_types, change_audio_types,
) )
from coqui_stt_training.util.augmentations import (
SampleAugmentation,
apply_sample_augmentations,
parse_augmentations,
)
from coqui_stt_training.util.downloader import SIMPLE_BAR from coqui_stt_training.util.downloader import SIMPLE_BAR
from coqui_stt_training.util.sample_collections import ( from coqui_stt_training.util.sample_collections import (
CSVWriter, CSVWriter,
@ -21,101 +26,110 @@ from coqui_stt_training.util.sample_collections import (
TarWriter, TarWriter,
samples_from_sources, samples_from_sources,
) )
from coqui_stt_training.util.augmentations import (
parse_augmentations,
apply_sample_augmentations,
SampleAugmentation
)
AUDIO_TYPE_LOOKUP = {'wav': AUDIO_TYPE_WAV, 'opus': AUDIO_TYPE_OPUS} AUDIO_TYPE_LOOKUP = {"wav": AUDIO_TYPE_WAV, "opus": AUDIO_TYPE_OPUS}
def build_data_set(): def build_data_set():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type] audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
augmentations = parse_augmentations(CLI_ARGS.augment) augmentations = parse_augmentations(CLI_ARGS.augment)
if any(not isinstance(a, SampleAugmentation) for a in augmentations): if any(not isinstance(a, SampleAugmentation) for a in augmentations):
print('Warning: Some of the specified augmentations will not get applied, as this tool only supports ' print(
'overlay, codec, reverb, resample and volume.') "Warning: Some of the specified augmentations will not get applied, as this tool only supports "
"overlay, codec, reverb, resample and volume."
)
extension = Path(CLI_ARGS.target).suffix.lower() extension = Path(CLI_ARGS.target).suffix.lower()
labeled = not CLI_ARGS.unlabeled labeled = not CLI_ARGS.unlabeled
if extension == '.csv': if extension == ".csv":
writer = CSVWriter(CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled) writer = CSVWriter(
elif extension == '.sdb': CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled
writer = DirectSDBWriter(CLI_ARGS.target, audio_type=audio_type, labeled=labeled) )
elif extension == '.tar': elif extension == ".sdb":
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include) writer = DirectSDBWriter(
elif extension == '.tgz' or CLI_ARGS.target.lower().endswith('.tar.gz'): CLI_ARGS.target, audio_type=audio_type, labeled=labeled
writer = TarWriter(CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include) )
elif extension == ".tar":
writer = TarWriter(
CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include
)
elif extension == ".tgz" or CLI_ARGS.target.lower().endswith(".tar.gz"):
writer = TarWriter(
CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include
)
else: else:
print('Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz') print(
"Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz"
)
sys.exit(1) sys.exit(1)
with writer: with writer:
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled) samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
num_samples = len(samples) num_samples = len(samples)
if augmentations: if augmentations:
samples = apply_sample_augmentations(samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations) samples = apply_sample_augmentations(
samples, audio_type=AUDIO_TYPE_PCM, augmentations=augmentations
)
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for sample in bar(change_audio_types( for sample in bar(
change_audio_types(
samples, samples,
audio_type=audio_type, audio_type=audio_type,
bitrate=CLI_ARGS.bitrate, bitrate=CLI_ARGS.bitrate,
processes=CLI_ARGS.workers)): processes=CLI_ARGS.workers,
)
):
writer.add(sample) writer.add(sample)
def handle_args(): def handle_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Tool for building a combined SDB or CSV sample-set from other sets' description="Tool for building a combined SDB or CSV sample-set from other sets"
) )
parser.add_argument( parser.add_argument(
'sources', "sources",
nargs='+', nargs="+",
help='Source CSV and/or SDB files - ' help="Source CSV and/or SDB files - "
'Note: For getting a correctly ordered target set, source SDBs have to have their samples ' "Note: For getting a correctly ordered target set, source SDBs have to have their samples "
'already ordered from shortest to longest.', "already ordered from shortest to longest.",
) )
parser.add_argument("target", help="SDB, CSV or TAR(.gz) file to create")
parser.add_argument( parser.add_argument(
'target', "--audio-type",
help='SDB, CSV or TAR(.gz) file to create' default="opus",
)
parser.add_argument(
'--audio-type',
default='opus',
choices=AUDIO_TYPE_LOOKUP.keys(), choices=AUDIO_TYPE_LOOKUP.keys(),
help='Audio representation inside target SDB', help="Audio representation inside target SDB",
) )
parser.add_argument( parser.add_argument(
'--bitrate', "--bitrate",
type=int, type=int,
help='Bitrate for lossy compressed SDB samples like in case of --audio-type opus', help="Bitrate for lossy compressed SDB samples like in case of --audio-type opus",
) )
parser.add_argument( parser.add_argument(
'--workers', type=int, default=None, help='Number of encoding SDB workers' "--workers", type=int, default=None, help="Number of encoding SDB workers"
) )
parser.add_argument( parser.add_argument(
'--unlabeled', "--unlabeled",
action='store_true', action="store_true",
help='If to build an data-set with unlabeled (audio only) samples - ' help="If to build an data-set with unlabeled (audio only) samples - "
'typically used for building noise augmentation corpora', "typically used for building noise augmentation corpora",
) )
parser.add_argument( parser.add_argument(
'--absolute-paths', "--absolute-paths",
action='store_true', action="store_true",
help='If to reference samples by their absolute paths when writing CSV files', help="If to reference samples by their absolute paths when writing CSV files",
) )
parser.add_argument( parser.add_argument(
'--augment', "--augment",
action='append', action="append",
help='Add an augmentation operation', help="Add an augmentation operation",
) )
parser.add_argument( parser.add_argument(
'--include', "--include",
action='append', action="append",
help='Adds a file to the root directory of .tar(.gz) targets', help="Adds a file to the root directory of .tar(.gz) targets",
) )
return parser.parse_args() return parser.parse_args()
if __name__ == '__main__': if __name__ == "__main__":
CLI_ARGS = handle_args() CLI_ARGS = handle_args()
build_data_set() build_data_set()

View File

@ -3,9 +3,10 @@
import sys import sys
import tensorflow.compat.v1 as tfv1
from google.protobuf import text_format from google.protobuf import text_format
import tensorflow.compat.v1 as tfv1
def main(): def main():
# Load and export as string # Load and export as string

View File

@ -4,7 +4,6 @@ import os
import tarfile import tarfile
import pandas import pandas
from coqui_stt_training.util.importers import get_importers_parser from coqui_stt_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]

View File

@ -4,7 +4,6 @@ import os
import tarfile import tarfile
import pandas import pandas
from coqui_stt_training.util.importers import get_importers_parser from coqui_stt_training.util.importers import get_importers_parser
COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"] COLUMNNAMES = ["wav_filename", "wav_filesize", "transcript"]

View File

@ -5,21 +5,21 @@ Ministère de l'Économie, des Finances et de la Relance
""" """
import csv import csv
import sys import decimal
import hashlib
import math
import os import os
import progressbar import re
import subprocess import subprocess
import sys
import unicodedata
import xml.etree.ElementTree as ET
import zipfile import zipfile
from glob import glob from glob import glob
from multiprocessing import Pool from multiprocessing import Pool
import hashlib import progressbar
import decimal
import math
import unicodedata
import re
import sox import sox
import xml.etree.ElementTree as ET
try: try:
from num2words import num2words from num2words import num2words
@ -27,19 +27,19 @@ except ImportError as ex:
print("pip install num2words") print("pip install num2words")
sys.exit(1) sys.exit(1)
import requests
import json import json
import requests
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.helpers import secs_to_hours from coqui_stt_training.util.helpers import secs_to_hours
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,
get_importers_parser,
get_imported_samples, get_imported_samples,
get_importers_parser,
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from coqui_stt_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -50,58 +50,187 @@ MIN_SECS = 0.85
DATASET_RELEASE_CSV = "https://data.economie.gouv.fr/explore/dataset/transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020/download/?format=csv&timezone=Europe/Berlin&lang=fr&use_labels_for_header=true&csv_separator=%3B" DATASET_RELEASE_CSV = "https://data.economie.gouv.fr/explore/dataset/transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020/download/?format=csv&timezone=Europe/Berlin&lang=fr&use_labels_for_header=true&csv_separator=%3B"
DATASET_RELEASE_SHA = [ DATASET_RELEASE_SHA = [
("863d39a06a388c6491c6ff2f6450b151f38f1b57", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.001"), (
("2f3a0305aa04c61220bb00b5a4e553e45dbf12e1", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.002"), "863d39a06a388c6491c6ff2f6450b151f38f1b57",
("5e55e9f1f844097349188ac875947e5a3d7fe9f1", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.003"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.001",
("8bf54842cf07948ca5915e27a8bd5fa5139c06ae", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.004"), ),
("c8963504aadc015ac48f9af80058a0bb3440b94f", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.005"), (
("d95e225e908621d83ce4e9795fd108d9d310e244", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.006"), "2f3a0305aa04c61220bb00b5a4e553e45dbf12e1",
("de6ed9c2b0ee80ca879aae8ba7923cc93217d811", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.007"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.002",
("234283c47dacfcd4450d836c52c25f3e807fc5f2", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.008"), ),
("4e6b67a688639bb72f8cd81782eaba604a8d32a6", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.009"), (
("4165a51389777c8af8e6253d87bdacb877e8b3b0", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.010"), "5e55e9f1f844097349188ac875947e5a3d7fe9f1",
("34322e7009780d97ef5bd02bf2f2c7a31f00baff", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.011"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.003",
("48c5be3b2ca9d6108d525da6a03e91d93a95dbac", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.012"), ),
("87573172f506a189c2ebc633856fe11a2e9cd213", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.013"), (
("6ab2c9e508e9278d5129f023e018725c4a7c69e8", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.014"), "8bf54842cf07948ca5915e27a8bd5fa5139c06ae",
("4f84df831ef46dce5d3ab3e21817687a2d8c12d0", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.015"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.004",
("e69bfb079885c299cb81080ef88b1b8b57158aa6", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.016"), ),
("5f764ba788ee273981cf211b242c29b49ca22c5e", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.017"), (
("b6aa81a959525363223494830c1e7307d4c4bae6", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.018"), "c8963504aadc015ac48f9af80058a0bb3440b94f",
("91ddcf43c7bf113a6f2528b857c7ec22a50a148a", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.019"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.005",
("fa1b29273dd77b9a7494983a2f9ae52654b931d7", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.020"), ),
("1113aef4f5e2be2f7fbf2d54b6c710c1c0e7135f", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.021"), (
("ce6420d5d0b6b5135ba559f83e1a82d4d615c470", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.022"), "d95e225e908621d83ce4e9795fd108d9d310e244",
("d0976ed292ac24fcf1590d1ea195077c74b05471", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.023"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.006",
("ec746cd6af066f62d9bf8d3b2f89174783ff4e3c", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.024"), ),
("570d9e1e84178e32fd867171d4b3aaecda1fd4fb", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.025"), (
("c29ccc7467a75b2cae3d7f2e9fbbb2ab276cb8ac", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.026"), "de6ed9c2b0ee80ca879aae8ba7923cc93217d811",
("08406a51146d88e208704ce058c060a1e44efa50", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.027"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.007",
("199aedad733a78ea1e7d47def9c71c6fd5795e02", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.028"), ),
("db856a068f92fb4f01f410bba42c7271de0f231a", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.029"), (
("e3c0135f16c6c9d25a09dcb4f99a685438a84740", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.030"), "234283c47dacfcd4450d836c52c25f3e807fc5f2",
("e51b8bb9c0ae4339f98b4f21e6d29b825109f0ac", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.031"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.008",
("be5e80cbc49b59b31ae33c30576ef0e1a162d84e", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.032"), ),
("501df58e3ff55fcfd75b93dab57566dc536948b8", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.033"), (
("1a114875811a8cdcb8d85a9f6dbee78be3e05131", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.034"), "4e6b67a688639bb72f8cd81782eaba604a8d32a6",
("465d824e7ee46448369182c0c28646d155a2249b", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.035"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.009",
("37f341b1b266d143eb73138c31cfff3201b9d619", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.036"), ),
("9e7d8255987a8a77a90e0d4b55c8fd38b9fb5694", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.037"), (
("54886755630cb080a53098cb1b6c951c6714a143", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.038"), "4165a51389777c8af8e6253d87bdacb877e8b3b0",
("4b7cbb0154697be795034f7a49712e882a97197a", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.039"), "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.010",
("c8e1e565a0e7a1f6ff1dbfcefe677aa74a41d2f2", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.040"), ),
(
"34322e7009780d97ef5bd02bf2f2c7a31f00baff",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.011",
),
(
"48c5be3b2ca9d6108d525da6a03e91d93a95dbac",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.012",
),
(
"87573172f506a189c2ebc633856fe11a2e9cd213",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.013",
),
(
"6ab2c9e508e9278d5129f023e018725c4a7c69e8",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.014",
),
(
"4f84df831ef46dce5d3ab3e21817687a2d8c12d0",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.015",
),
(
"e69bfb079885c299cb81080ef88b1b8b57158aa6",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.016",
),
(
"5f764ba788ee273981cf211b242c29b49ca22c5e",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.017",
),
(
"b6aa81a959525363223494830c1e7307d4c4bae6",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.018",
),
(
"91ddcf43c7bf113a6f2528b857c7ec22a50a148a",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.019",
),
(
"fa1b29273dd77b9a7494983a2f9ae52654b931d7",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.020",
),
(
"1113aef4f5e2be2f7fbf2d54b6c710c1c0e7135f",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.021",
),
(
"ce6420d5d0b6b5135ba559f83e1a82d4d615c470",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.022",
),
(
"d0976ed292ac24fcf1590d1ea195077c74b05471",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.023",
),
(
"ec746cd6af066f62d9bf8d3b2f89174783ff4e3c",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.024",
),
(
"570d9e1e84178e32fd867171d4b3aaecda1fd4fb",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.025",
),
(
"c29ccc7467a75b2cae3d7f2e9fbbb2ab276cb8ac",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.026",
),
(
"08406a51146d88e208704ce058c060a1e44efa50",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.027",
),
(
"199aedad733a78ea1e7d47def9c71c6fd5795e02",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.028",
),
(
"db856a068f92fb4f01f410bba42c7271de0f231a",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.029",
),
(
"e3c0135f16c6c9d25a09dcb4f99a685438a84740",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.030",
),
(
"e51b8bb9c0ae4339f98b4f21e6d29b825109f0ac",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.031",
),
(
"be5e80cbc49b59b31ae33c30576ef0e1a162d84e",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.032",
),
(
"501df58e3ff55fcfd75b93dab57566dc536948b8",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.033",
),
(
"1a114875811a8cdcb8d85a9f6dbee78be3e05131",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.034",
),
(
"465d824e7ee46448369182c0c28646d155a2249b",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.035",
),
(
"37f341b1b266d143eb73138c31cfff3201b9d619",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.036",
),
(
"9e7d8255987a8a77a90e0d4b55c8fd38b9fb5694",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.037",
),
(
"54886755630cb080a53098cb1b6c951c6714a143",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.038",
),
(
"4b7cbb0154697be795034f7a49712e882a97197a",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.039",
),
(
"c8e1e565a0e7a1f6ff1dbfcefe677aa74a41d2f2",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip.040",
),
] ]
def _download_and_preprocess_data(csv_url, target_dir): def _download_and_preprocess_data(csv_url, target_dir):
dataset_sources = os.path.join(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", "data.txt") dataset_sources = os.path.join(
target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", "data.txt"
)
if os.path.exists(dataset_sources): if os.path.exists(dataset_sources):
return dataset_sources return dataset_sources
# Making path absolute # Making path absolute
target_dir = os.path.abspath(target_dir) target_dir = os.path.abspath(target_dir)
csv_ref = requests.get(csv_url).text.split('\r\n')[1:-1] csv_ref = requests.get(csv_url).text.split("\r\n")[1:-1]
for part in csv_ref: for part in csv_ref:
part_filename = requests.head(part).headers.get("Content-Disposition").split(" ")[1].split("=")[1].replace('"', "") part_filename = (
requests.head(part)
.headers.get("Content-Disposition")
.split(" ")[1]
.split("=")[1]
.replace('"', "")
)
if not os.path.exists(os.path.join(target_dir, part_filename)): if not os.path.exists(os.path.join(target_dir, part_filename)):
part_path = maybe_download(part_filename, target_dir, part) part_path = maybe_download(part_filename, target_dir, part)
@ -126,10 +255,18 @@ def _download_and_preprocess_data(csv_url, target_dir):
assert csum == sha1 assert csum == sha1
# Conditionally extract data # Conditionally extract data
_maybe_extract(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020", "transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip", "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020.zip") _maybe_extract(
target_dir,
"transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020",
"transcriptionsxml_audiomp3_mefr_ccpmf_2012-2020_2.zip",
"transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020.zip",
)
# Produce source text for extraction / conversion # Produce source text for extraction / conversion
return _maybe_create_sources(os.path.join(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020")) return _maybe_create_sources(
os.path.join(target_dir, "transcriptionsXML_audioMP3_MEFR_CCPMF_2012-2020")
)
def _maybe_extract(target_dir, extracted_data, archive, final): def _maybe_extract(target_dir, extracted_data, archive, final):
# If target_dir/extracted_data does not exist, extract archive in target_dir # If target_dir/extracted_data does not exist, extract archive in target_dir
@ -147,7 +284,10 @@ def _maybe_extract(target_dir, extracted_data, archive, final):
subprocess.check_call(cmdline, shell=True, cwd=target_dir) subprocess.check_call(cmdline, shell=True, cwd=target_dir)
assert os.path.exists(archive_path) assert os.path.exists(archive_path)
print('No directory "%s" - extracting archive %s ...' % (extracted_path, archive_path)) print(
'No directory "%s" - extracting archive %s ...'
% (extracted_path, archive_path)
)
with zipfile.ZipFile(archive_path) as zip_f: with zipfile.ZipFile(archive_path) as zip_f:
zip_f.extractall(extracted_path) zip_f.extractall(extracted_path)
@ -156,6 +296,7 @@ def _maybe_extract(target_dir, extracted_data, archive, final):
else: else:
print('Found directory "%s" - not extracting it from archive.' % extracted_path) print('Found directory "%s" - not extracting it from archive.' % extracted_path)
def _maybe_create_sources(dir): def _maybe_create_sources(dir):
dataset_sources = os.path.join(dir, "data.txt") dataset_sources = os.path.join(dir, "data.txt")
MP3 = glob(os.path.join(dir, "**", "*.mp3")) MP3 = glob(os.path.join(dir, "**", "*.mp3"))
@ -168,8 +309,8 @@ def _maybe_create_sources(dir):
for f_xml in XML: for f_xml in XML:
b_mp3 = os.path.splitext(os.path.basename(f_mp3))[0] b_mp3 = os.path.splitext(os.path.basename(f_mp3))[0]
b_xml = os.path.splitext(os.path.basename(f_xml))[0] b_xml = os.path.splitext(os.path.basename(f_xml))[0]
a_mp3 = b_mp3.split('_') a_mp3 = b_mp3.split("_")
a_xml = b_xml.split('_') a_xml = b_xml.split("_")
score = 0 score = 0
date_mp3 = a_mp3[0] date_mp3 = a_mp3[0]
date_xml = a_xml[0] date_xml = a_xml[0]
@ -178,7 +319,7 @@ def _maybe_create_sources(dir):
continue continue
for i in range(min(len(a_mp3), len(a_xml))): for i in range(min(len(a_mp3), len(a_xml))):
if (a_mp3[i] == a_xml[i]): if a_mp3[i] == a_xml[i]:
score += 1 score += 1
if score >= 1: if score >= 1:
@ -187,7 +328,7 @@ def _maybe_create_sources(dir):
# sort by score # sort by score
MP3_XML_Scores.sort(key=lambda x: x[2], reverse=True) MP3_XML_Scores.sort(key=lambda x: x[2], reverse=True)
for s_mp3, s_xml, score in MP3_XML_Scores: for s_mp3, s_xml, score in MP3_XML_Scores:
#print(s_mp3, s_xml, score) # print(s_mp3, s_xml, score)
if score not in MP3_XML_Fin: if score not in MP3_XML_Fin:
MP3_XML_Fin[score] = {} MP3_XML_Fin[score] = {}
@ -208,13 +349,14 @@ def _maybe_create_sources(dir):
if os.path.getsize(mp3) > 0 and os.path.getsize(xml) > 0: if os.path.getsize(mp3) > 0 and os.path.getsize(xml) > 0:
mp3 = os.path.relpath(mp3, dir) mp3 = os.path.relpath(mp3, dir)
xml = os.path.relpath(xml, dir) xml = os.path.relpath(xml, dir)
ds.write('{},{},{:0.2e}\n'.format(xml, mp3, 2.5e-4)) ds.write("{},{},{:0.2e}\n".format(xml, mp3, 2.5e-4))
else: else:
print("Empty file {} or {}".format(mp3, xml), file=sys.stderr) print("Empty file {} or {}".format(mp3, xml), file=sys.stderr)
print("Missing XML pairs:", MP3, file=sys.stderr) print("Missing XML pairs:", MP3, file=sys.stderr)
return dataset_sources return dataset_sources
def maybe_normalize_for_digits(label): def maybe_normalize_for_digits(label):
# first, try to identify numbers like "50 000", "260 000" # first, try to identify numbers like "50 000", "260 000"
if " " in label: if " " in label:
@ -234,30 +376,44 @@ def maybe_normalize_for_digits(label):
date_or_time = re.compile(r"(\d{1,2}):(\d{2}):?(\d{2})?") date_or_time = re.compile(r"(\d{1,2}):(\d{2}):?(\d{2})?")
maybe_date_or_time = date_or_time.findall(s) maybe_date_or_time = date_or_time.findall(s)
if len(maybe_date_or_time) > 0: if len(maybe_date_or_time) > 0:
maybe_hours = maybe_date_or_time[0][0] maybe_hours = maybe_date_or_time[0][0]
maybe_minutes = maybe_date_or_time[0][1] maybe_minutes = maybe_date_or_time[0][1]
maybe_seconds = maybe_date_or_time[0][2] maybe_seconds = maybe_date_or_time[0][2]
if len(maybe_seconds) > 0: if len(maybe_seconds) > 0:
label = label.replace("{}:{}:{}".format(maybe_hours, maybe_minutes, maybe_seconds), "{} heures {} minutes et {} secondes".format(maybe_hours, maybe_minutes, maybe_seconds)) label = label.replace(
"{}:{}:{}".format(
maybe_hours, maybe_minutes, maybe_seconds
),
"{} heures {} minutes et {} secondes".format(
maybe_hours, maybe_minutes, maybe_seconds
),
)
else: else:
label = label.replace("{}:{}".format(maybe_hours, maybe_minutes), "{} heures et {} minutes".format(maybe_hours, maybe_minutes)) label = label.replace(
"{}:{}".format(maybe_hours, maybe_minutes),
"{} heures et {} minutes".format(
maybe_hours, maybe_minutes
),
)
new_label = [] new_label = []
# pylint: disable=too-many-nested-blocks # pylint: disable=too-many-nested-blocks
for s in label.split(" "): for s in label.split(" "):
if any(i.isdigit() for i in s): if any(i.isdigit() for i in s):
s = s.replace(",", ".") # num2words requires "." for floats s = s.replace(",", ".") # num2words requires "." for floats
s = s.replace("\"", "") # clean some data, num2words would choke on 1959" s = s.replace('"', "") # clean some data, num2words would choke on 1959"
last_c = s[-1] last_c = s[-1]
if not last_c.isdigit(): # num2words will choke on "0.6.", "24 ?" if not last_c.isdigit(): # num2words will choke on "0.6.", "24 ?"
s = s[:-1] s = s[:-1]
if any(i.isalpha() for i in s): # So we have any(isdigit()) **and** any(sialpha), like "3D" if any(
i.isalpha() for i in s
): # So we have any(isdigit()) **and** any(sialpha), like "3D"
ns = [] ns = []
for c in s: for c in s:
nc = c nc = c
if c.isdigit(): # convert "3" to "trois-" if c.isdigit(): # convert "3" to "trois-"
try: try:
nc = num2words(c, lang="fr") + "-" nc = num2words(c, lang="fr") + "-"
except decimal.InvalidOperation as ex: except decimal.InvalidOperation as ex:
@ -274,22 +430,36 @@ def maybe_normalize_for_digits(label):
new_label.append(s) new_label.append(s)
return " ".join(new_label) return " ".join(new_label)
def maybe_normalize_for_specials_chars(label): def maybe_normalize_for_specials_chars(label):
label = label.replace("%", "pourcents") label = label.replace("%", "pourcents")
label = label.replace("/", ", ") # clean intervals like 2019/2022 to "2019 2022" label = label.replace("/", ", ") # clean intervals like 2019/2022 to "2019 2022"
label = label.replace("-", ", ") # clean intervals like 70-80 to "70 80" label = label.replace("-", ", ") # clean intervals like 70-80 to "70 80"
label = label.replace("+", " plus ") # clean + and make it speakable label = label.replace("+", " plus ") # clean + and make it speakable
label = label.replace("", " euros ") # clean euro symbol and make it speakable label = label.replace("", " euros ") # clean euro symbol and make it speakable
label = label.replace("., ", ", ") # clean some strange "4.0., " (20181017_Innovation.xml) label = label.replace(
label = label.replace("°", " degré ") # clean some strange "°5" (20181210_EtatsGeneraux-1000_fre_750_und.xml) "., ", ", "
label = label.replace("...", ".") # remove ellipsis ) # clean some strange "4.0., " (20181017_Innovation.xml)
label = label.replace("..", ".") # remove broken ellipsis label = label.replace(
label = label.replace("", "mètre-carrés") # 20150616_Defi_Climat_3_wmv_0_fre_minefi.xml "°", " degré "
label = label.replace("[end]", "") # broken tag in 20150123_Entretiens_Tresor_PGM_wmv_0_fre_minefi.xml ) # clean some strange "°5" (20181210_EtatsGeneraux-1000_fre_750_und.xml)
label = label.replace(u'\xB8c', " ç") # strange cedilla in 20150417_Printemps_Economie_2_wmv_0_fre_minefi.xml label = label.replace("...", ".") # remove ellipsis
label = label.replace("C0²", "CO 2") # 20121016_Syteme_sante_copie_wmv_0_fre_minefi.xml label = label.replace("..", ".") # remove broken ellipsis
label = label.replace(
"", "mètre-carrés"
) # 20150616_Defi_Climat_3_wmv_0_fre_minefi.xml
label = label.replace(
"[end]", ""
) # broken tag in 20150123_Entretiens_Tresor_PGM_wmv_0_fre_minefi.xml
label = label.replace(
u"\xB8c", " ç"
) # strange cedilla in 20150417_Printemps_Economie_2_wmv_0_fre_minefi.xml
label = label.replace(
"C0²", "CO 2"
) # 20121016_Syteme_sante_copie_wmv_0_fre_minefi.xml
return label return label
def maybe_normalize_for_anglicisms(label): def maybe_normalize_for_anglicisms(label):
label = label.replace("B2B", "B to B") label = label.replace("B2B", "B to B")
label = label.replace("B2C", "B to C") label = label.replace("B2C", "B to C")
@ -297,12 +467,14 @@ def maybe_normalize_for_anglicisms(label):
label = label.replace("@", "at ") label = label.replace("@", "at ")
return label return label
def maybe_normalize(label): def maybe_normalize(label):
label = maybe_normalize_for_specials_chars(label) label = maybe_normalize_for_specials_chars(label)
label = maybe_normalize_for_anglicisms(label) label = maybe_normalize_for_anglicisms(label)
label = maybe_normalize_for_digits(label) label = maybe_normalize_for_digits(label)
return label return label
def one_sample(sample): def one_sample(sample):
file_size = -1 file_size = -1
frames = 0 frames = 0
@ -316,14 +488,33 @@ def one_sample(sample):
label = label_filter_fun(sample[5]) label = label_filter_fun(sample[5])
sample_id = sample[6] sample_id = sample[6]
_wav_filename = os.path.basename(audio_source.replace(".wav", "_{:06}.wav".format(sample_id))) _wav_filename = os.path.basename(
audio_source.replace(".wav", "_{:06}.wav".format(sample_id))
)
wav_fullname = os.path.join(target_dir, dataset_basename, _wav_filename) wav_fullname = os.path.join(target_dir, dataset_basename, _wav_filename)
if not os.path.exists(wav_fullname): if not os.path.exists(wav_fullname):
subprocess.check_output(["ffmpeg", "-i", audio_source, "-ss", str(start_time), "-t", str(duration), "-c", "copy", wav_fullname], stdin=subprocess.DEVNULL, stderr=subprocess.STDOUT) subprocess.check_output(
[
"ffmpeg",
"-i",
audio_source,
"-ss",
str(start_time),
"-t",
str(duration),
"-c",
"copy",
wav_fullname,
],
stdin=subprocess.DEVNULL,
stderr=subprocess.STDOUT,
)
file_size = os.path.getsize(wav_fullname) file_size = os.path.getsize(wav_fullname)
frames = int(subprocess.check_output(["soxi", "-s", wav_fullname], stderr=subprocess.STDOUT)) frames = int(
subprocess.check_output(["soxi", "-s", wav_fullname], stderr=subprocess.STDOUT)
)
_counter = get_counter() _counter = get_counter()
_rows = [] _rows = []
@ -334,13 +525,13 @@ def one_sample(sample):
elif label is None: elif label is None:
# Excluding samples that failed on label validation # Excluding samples that failed on label validation
_counter["invalid_label"] += 1 _counter["invalid_label"] += 1
elif int(frames/SAMPLE_RATE*1000/10/2) < len(str(label)): elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)):
# Excluding samples that are too short to fit the transcript # Excluding samples that are too short to fit the transcript
_counter["too_short"] += 1 _counter["too_short"] += 1
elif frames/SAMPLE_RATE < MIN_SECS: elif frames / SAMPLE_RATE < MIN_SECS:
# Excluding samples that are too short # Excluding samples that are too short
_counter["too_short"] += 1 _counter["too_short"] += 1
elif frames/SAMPLE_RATE > MAX_SECS: elif frames / SAMPLE_RATE > MAX_SECS:
# Excluding very long samples to keep a reasonable batch-size # Excluding very long samples to keep a reasonable batch-size
_counter["too_long"] += 1 _counter["too_long"] += 1
else: else:
@ -352,56 +543,71 @@ def one_sample(sample):
return (_counter, _rows) return (_counter, _rows)
def _maybe_import_data(xml_file, audio_source, target_dir, rel_tol=1e-1): def _maybe_import_data(xml_file, audio_source, target_dir, rel_tol=1e-1):
dataset_basename = os.path.splitext(os.path.split(xml_file)[1])[0] dataset_basename = os.path.splitext(os.path.split(xml_file)[1])[0]
wav_root = os.path.join(target_dir, dataset_basename) wav_root = os.path.join(target_dir, dataset_basename)
if not os.path.exists(wav_root): if not os.path.exists(wav_root):
os.makedirs(wav_root) os.makedirs(wav_root)
source_frames = int(subprocess.check_output(["soxi", "-s", audio_source], stderr=subprocess.STDOUT)) source_frames = int(
subprocess.check_output(["soxi", "-s", audio_source], stderr=subprocess.STDOUT)
)
print("Source audio length: %s" % secs_to_hours(source_frames / SAMPLE_RATE)) print("Source audio length: %s" % secs_to_hours(source_frames / SAMPLE_RATE))
# Get audiofile path and transcript for each sentence in tsv # Get audiofile path and transcript for each sentence in tsv
samples = [] samples = []
tree = ET.parse(xml_file) tree = ET.parse(xml_file)
root = tree.getroot() root = tree.getroot()
seq_id = 0 seq_id = 0
this_time = 0.0 this_time = 0.0
this_duration = 0.0 this_duration = 0.0
prev_time = 0.0 prev_time = 0.0
prev_duration = 0.0 prev_duration = 0.0
this_text = "" this_text = ""
for child in root: for child in root:
if child.tag == "row": if child.tag == "row":
cur_time = float(child.attrib["timestamp"]) cur_time = float(child.attrib["timestamp"])
cur_duration = float(child.attrib["timedur"]) cur_duration = float(child.attrib["timedur"])
cur_text = child.text cur_text = child.text
if this_time == 0.0: if this_time == 0.0:
this_time = cur_time this_time = cur_time
delta = cur_time - (prev_time + prev_duration) delta = cur_time - (prev_time + prev_duration)
# rel_tol value is made from trial/error to try and compromise between: # rel_tol value is made from trial/error to try and compromise between:
# - cutting enough to skip missing words # - cutting enough to skip missing words
# - not too short, not too long sentences # - not too short, not too long sentences
is_close = math.isclose(cur_time, this_time + this_duration, rel_tol=rel_tol) is_close = math.isclose(
is_short = ((this_duration + cur_duration + delta) < MAX_SECS) cur_time, this_time + this_duration, rel_tol=rel_tol
)
is_short = (this_duration + cur_duration + delta) < MAX_SECS
# when the previous element is close enough **and** this does not # when the previous element is close enough **and** this does not
# go over MAX_SECS, we append content # go over MAX_SECS, we append content
if (is_close and is_short): if is_close and is_short:
this_duration += cur_duration + delta this_duration += cur_duration + delta
this_text += cur_text this_text += cur_text
else: else:
samples.append((audio_source, target_dir, dataset_basename, this_time, this_duration, this_text, seq_id)) samples.append(
(
audio_source,
target_dir,
dataset_basename,
this_time,
this_duration,
this_text,
seq_id,
)
)
this_time = cur_time this_time = cur_time
this_duration = cur_duration this_duration = cur_duration
this_text = cur_text this_text = cur_text
seq_id += 1 seq_id += 1
prev_time = cur_time prev_time = cur_time
prev_duration = cur_duration prev_duration = cur_duration
# Keep track of how many samples are good vs. problematic # Keep track of how many samples are good vs. problematic
@ -425,21 +631,27 @@ def _maybe_import_data(xml_file, audio_source, target_dir, rel_tol=1e-1):
assert len(_rows) == imported_samples assert len(_rows) == imported_samples
print_import_report(_counter, SAMPLE_RATE, MAX_SECS) print_import_report(_counter, SAMPLE_RATE, MAX_SECS)
print("Import efficiency: %.1f%%" % ((_counter["total_time"] / source_frames)*100)) print(
"Import efficiency: %.1f%%" % ((_counter["total_time"] / source_frames) * 100)
)
print("") print("")
return _counter, _rows return _counter, _rows
def _maybe_convert_wav(mp3_filename, _wav_filename): def _maybe_convert_wav(mp3_filename, _wav_filename):
if not os.path.exists(_wav_filename): if not os.path.exists(_wav_filename):
print("Converting {} to WAV file: {}".format(mp3_filename, _wav_filename)) print("Converting {} to WAV file: {}".format(mp3_filename, _wav_filename))
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS, bitdepth=BIT_DEPTH) transformer.convert(
samplerate=SAMPLE_RATE, n_channels=CHANNELS, bitdepth=BIT_DEPTH
)
try: try:
transformer.build(mp3_filename, _wav_filename) transformer.build(mp3_filename, _wav_filename)
except sox.core.SoxError: except sox.core.SoxError:
pass pass
def write_general_csv(target_dir, _rows, _counter): def write_general_csv(target_dir, _rows, _counter):
target_csv_template = os.path.join(target_dir, "ccpmf_{}.csv") target_csv_template = os.path.join(target_dir, "ccpmf_{}.csv")
with open(target_csv_template.format("train"), "w") as train_csv_file: # 80% with open(target_csv_template.format("train"), "w") as train_csv_file: # 80%
@ -461,7 +673,13 @@ def write_general_csv(target_dir, _rows, _counter):
writer = dev_writer writer = dev_writer
else: else:
writer = train_writer writer = train_writer
writer.writerow({"wav_filename": item[0], "wav_filesize": item[1], "transcript": item[2]}) writer.writerow(
{
"wav_filename": item[0],
"wav_filesize": item[1],
"transcript": item[2],
}
)
print("") print("")
print("~~~~ FINAL STATISTICS ~~~~") print("~~~~ FINAL STATISTICS ~~~~")
@ -469,11 +687,21 @@ def write_general_csv(target_dir, _rows, _counter):
print("~~~~ (FINAL STATISTICS) ~~~~") print("~~~~ (FINAL STATISTICS) ~~~~")
print("") print("")
if __name__ == "__main__": if __name__ == "__main__":
PARSER = get_importers_parser(description="Import XML from Conference Centre for Economics, France") PARSER = get_importers_parser(
description="Import XML from Conference Centre for Economics, France"
)
PARSER.add_argument("target_dir", help="Destination directory") PARSER.add_argument("target_dir", help="Destination directory")
PARSER.add_argument("--filter_alphabet", help="Exclude samples with characters not in provided alphabet") PARSER.add_argument(
PARSER.add_argument("--normalize", action="store_true", help="Converts diacritic characters to their base ones") "--filter_alphabet",
help="Exclude samples with characters not in provided alphabet",
)
PARSER.add_argument(
"--normalize",
action="store_true",
help="Converts diacritic characters to their base ones",
)
PARAMS = PARSER.parse_args() PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS) validate_label = get_validate_label(PARAMS)
@ -481,9 +709,11 @@ if __name__ == "__main__":
def label_filter_fun(label): def label_filter_fun(label):
if PARAMS.normalize: if PARAMS.normalize:
label = unicodedata.normalize("NFKD", label.strip()) \ label = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore") .decode("ascii", "ignore")
)
label = maybe_normalize(label) label = maybe_normalize(label)
label = validate_label(label) label = validate_label(label)
if ALPHABET and label: if ALPHABET and label:
@ -493,7 +723,9 @@ if __name__ == "__main__":
label = None label = None
return label return label
dataset_sources = _download_and_preprocess_data(csv_url=DATASET_RELEASE_CSV, target_dir=PARAMS.target_dir) dataset_sources = _download_and_preprocess_data(
csv_url=DATASET_RELEASE_CSV, target_dir=PARAMS.target_dir
)
sources_root_dir = os.path.dirname(dataset_sources) sources_root_dir = os.path.dirname(dataset_sources)
all_counter = get_counter() all_counter = get_counter()
all_rows = [] all_rows = []
@ -504,9 +736,14 @@ if __name__ == "__main__":
this_mp3 = os.path.join(sources_root_dir, d[1]) this_mp3 = os.path.join(sources_root_dir, d[1])
this_rel = float(d[2]) this_rel = float(d[2])
wav_filename = os.path.join(sources_root_dir, os.path.splitext(os.path.basename(this_mp3))[0] + ".wav") wav_filename = os.path.join(
sources_root_dir,
os.path.splitext(os.path.basename(this_mp3))[0] + ".wav",
)
_maybe_convert_wav(this_mp3, wav_filename) _maybe_convert_wav(this_mp3, wav_filename)
counter, rows = _maybe_import_data(this_xml, wav_filename, sources_root_dir, this_rel) counter, rows = _maybe_import_data(
this_xml, wav_filename, sources_root_dir, this_rel
)
all_counter += counter all_counter += counter
all_rows += rows all_rows += rows

View File

@ -1,15 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
import csv import csv
import os import os
import sys
import subprocess import subprocess
import sys
import tarfile import tarfile
from glob import glob from glob import glob
from multiprocessing import Pool from multiprocessing import Pool
import progressbar import progressbar
import sox import sox
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,

View File

@ -14,7 +14,7 @@ from multiprocessing import Pool
import progressbar import progressbar
import sox import sox
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR from coqui_stt_training.util.downloader import SIMPLE_BAR
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,
@ -23,7 +23,6 @@ from coqui_stt_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from coqui_stt_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -41,7 +40,11 @@ class LabelFilter:
def filter(self, label): def filter(self, label):
if self.normalize: if self.normalize:
label = unicodedata.normalize("NFKD", label.strip()).encode("ascii", "ignore").decode("ascii", "ignore") label = (
unicodedata.normalize("NFKD", label.strip())
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
label = self.validate_fun(label) label = self.validate_fun(label)
if self.alphabet and label and not self.alphabet.CanEncode(label): if self.alphabet and label and not self.alphabet.CanEncode(label):
label = None label = None
@ -97,7 +100,15 @@ def one_sample(sample):
return (counter, rows) return (counter, rows)
def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_every_character=None, rows=None, exclude=None): def _maybe_convert_set(
dataset,
tsv_dir,
audio_dir,
filter_obj,
space_after_every_character=None,
rows=None,
exclude=None,
):
exclude_transcripts = set() exclude_transcripts = set()
exclude_speakers = set() exclude_speakers = set()
if exclude is not None: if exclude is not None:
@ -116,7 +127,13 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
with open(input_tsv, encoding="utf-8") as input_tsv_file: with open(input_tsv, encoding="utf-8") as input_tsv_file:
reader = csv.DictReader(input_tsv_file, delimiter="\t") reader = csv.DictReader(input_tsv_file, delimiter="\t")
for row in reader: for row in reader:
samples.append((os.path.join(audio_dir, row["path"]), row["sentence"], row["client_id"])) samples.append(
(
os.path.join(audio_dir, row["path"]),
row["sentence"],
row["client_id"],
)
)
counter = get_counter() counter = get_counter()
num_samples = len(samples) num_samples = len(samples)
@ -124,7 +141,9 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
print("Importing mp3 files...") print("Importing mp3 files...")
pool = Pool(initializer=init_worker, initargs=(PARAMS,)) pool = Pool(initializer=init_worker, initargs=(PARAMS,))
bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
for i, processed in enumerate(pool.imap_unordered(one_sample, samples), start=1): for i, processed in enumerate(
pool.imap_unordered(one_sample, samples), start=1
):
counter += processed[0] counter += processed[0]
rows += processed[1] rows += processed[1]
bar.update(i) bar.update(i)
@ -169,12 +188,20 @@ def _maybe_convert_set(dataset, tsv_dir, audio_dir, filter_obj, space_after_ever
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False): def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
exclude = [] exclude = []
for dataset in ["test", "dev", "train", "validated", "other"]: for dataset in ["test", "dev", "train", "validated", "other"]:
set_samples = _maybe_convert_set(dataset, tsv_dir, audio_dir, space_after_every_character) set_samples = _maybe_convert_set(
dataset, tsv_dir, audio_dir, space_after_every_character
)
if dataset in ["test", "dev"]: if dataset in ["test", "dev"]:
exclude += set_samples exclude += set_samples
if dataset == "validated": if dataset == "validated":
_maybe_convert_set("train-all", tsv_dir, audio_dir, space_after_every_character, _maybe_convert_set(
rows=set_samples, exclude=exclude) "train-all",
tsv_dir,
audio_dir,
space_after_every_character,
rows=set_samples,
exclude=exclude,
)
def _maybe_convert_wav(mp3_filename, wav_filename): def _maybe_convert_wav(mp3_filename, wav_filename):
@ -212,7 +239,9 @@ def parse_args():
def main(): def main():
audio_dir = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips") audio_dir = (
PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, "clips")
)
_preprocess_data(PARAMS.tsv_dir, audio_dir, PARAMS.space_after_every_character) _preprocess_data(PARAMS.tsv_dir, audio_dir, PARAMS.space_after_every_character)

View File

@ -10,7 +10,6 @@ import unicodedata
import librosa import librosa
import pandas import pandas
import soundfile # <= Has an external dependency on libsndfile import soundfile # <= Has an external dependency on libsndfile
from coqui_stt_training.util.importers import validate_label_eng as validate_label from coqui_stt_training.util.importers import validate_label_eng as validate_label
# Prerequisite: Having the sph2pipe tool in your PATH: # Prerequisite: Having the sph2pipe tool in your PATH:
@ -261,8 +260,7 @@ def _split_sets(filelist):
def get_sample_size(population_size): def get_sample_size(population_size):
"""calculates the sample size for a 99% confidence and 1% margin of error """calculates the sample size for a 99% confidence and 1% margin of error"""
"""
margin_of_error = 0.01 margin_of_error = 0.01
fraction_picking = 0.50 fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99% z_score = 2.58 # Corresponds to confidence level 99%

View File

@ -5,7 +5,6 @@ import tarfile
import numpy as np import numpy as np
import pandas import pandas
from coqui_stt_training.util.importers import get_importers_parser from coqui_stt_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]

View File

@ -9,10 +9,9 @@ import urllib
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
from sox import Transformer
import swifter import swifter
from coqui_stt_training.util.importers import get_importers_parser, get_validate_label from coqui_stt_training.util.importers import get_importers_parser, get_validate_label
from sox import Transformer
__version__ = "0.1.0" __version__ = "0.1.0"
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)

View File

@ -3,7 +3,6 @@ import os
import sys import sys
import pandas import pandas
from coqui_stt_training.util.downloader import maybe_download from coqui_stt_training.util.downloader import maybe_download

View File

@ -9,10 +9,10 @@ import unicodedata
import pandas import pandas
import progressbar import progressbar
from sox import Transformer
from tensorflow.python.platform import gfile
from coqui_stt_training.util.downloader import maybe_download from coqui_stt_training.util.downloader import maybe_download
from sox import Transformer
from tensorflow.python.platform import gfile
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000

View File

@ -11,7 +11,7 @@ from multiprocessing import Pool
import progressbar import progressbar
import sox import sox
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,
@ -20,7 +20,6 @@ from coqui_stt_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from coqui_stt_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -137,9 +136,15 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% with open(
with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% target_csv_template.format("train"), "w", encoding="utf-8", newline=""
with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% ) as train_csv_file: # 80%
with open(
target_csv_template.format("dev"), "w", encoding="utf-8", newline=""
) as dev_csv_file: # 10%
with open(
target_csv_template.format("test"), "w", encoding="utf-8", newline=""
) as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)
@ -179,7 +184,9 @@ def _maybe_convert_sets(target_dir, extracted_data):
def _maybe_convert_wav(ogg_filename, wav_filename): def _maybe_convert_wav(ogg_filename, wav_filename):
if not os.path.exists(wav_filename): if not os.path.exists(wav_filename):
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.convert(samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH) transformer.convert(
samplerate=SAMPLE_RATE, n_channels=N_CHANNELS, bitdepth=BITDEPTH
)
try: try:
transformer.build(ogg_filename, wav_filename) transformer.build(ogg_filename, wav_filename)
except sox.core.SoxError as ex: except sox.core.SoxError as ex:

View File

@ -9,7 +9,7 @@ from glob import glob
from multiprocessing import Pool from multiprocessing import Pool
import progressbar import progressbar
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,
@ -18,7 +18,6 @@ from coqui_stt_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from coqui_stt_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -60,9 +59,20 @@ def one_sample(sample):
file_size = -1 file_size = -1
frames = 0 frames = 0
if os.path.exists(wav_filename): if os.path.exists(wav_filename):
tmp_filename = os.path.splitext(wav_filename)[0]+'.tmp.wav' tmp_filename = os.path.splitext(wav_filename)[0] + ".tmp.wav"
subprocess.check_call( subprocess.check_call(
['sox', wav_filename, '-r', str(SAMPLE_RATE), '-c', '1', '-b', '16', tmp_filename], stderr=subprocess.STDOUT [
"sox",
wav_filename,
"-r",
str(SAMPLE_RATE),
"-c",
"1",
"-b",
"16",
tmp_filename,
],
stderr=subprocess.STDOUT,
) )
os.rename(tmp_filename, wav_filename) os.rename(tmp_filename, wav_filename)
file_size = os.path.getsize(wav_filename) file_size = os.path.getsize(wav_filename)
@ -138,9 +148,15 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% with open(
with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% target_csv_template.format("train"), "w", encoding="utf-8", newline=""
with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% ) as train_csv_file: # 80%
with open(
target_csv_template.format("dev"), "w", encoding="utf-8", newline=""
) as dev_csv_file: # 10%
with open(
target_csv_template.format("test"), "w", encoding="utf-8", newline=""
) as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)

View File

@ -5,7 +5,6 @@ import tarfile
import wave import wave
import pandas import pandas
from coqui_stt_training.util.importers import get_importers_parser from coqui_stt_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]

View File

@ -2,10 +2,9 @@
import argparse import argparse
import ctypes import ctypes
import os import os
from pathlib import Path
import pandas import pandas
from pathlib import Path
from tqdm import tqdm from tqdm import tqdm

View File

@ -6,7 +6,6 @@ import tarfile
import numpy as np import numpy as np
import pandas import pandas
from coqui_stt_training.util.importers import get_importers_parser from coqui_stt_training.util.importers import get_importers_parser
COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"] COLUMN_NAMES = ["wav_filename", "wav_filesize", "transcript"]

View File

@ -8,7 +8,7 @@ from glob import glob
from multiprocessing import Pool from multiprocessing import Pool
import progressbar import progressbar
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,
@ -17,7 +17,6 @@ from coqui_stt_training.util.importers import (
get_validate_label, get_validate_label,
print_import_report, print_import_report,
) )
from coqui_stt_ctcdecoder import Alphabet
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"] FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
@ -157,9 +156,15 @@ def _maybe_convert_sets(target_dir, extracted_data):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% with open(
with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% target_csv_template.format("train"), "w", encoding="utf-8", newline=""
with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% ) as train_csv_file: # 80%
with open(
target_csv_template.format("dev"), "w", encoding="utf-8", newline=""
) as dev_csv_file: # 10%
with open(
target_csv_template.format("test"), "w", encoding="utf-8", newline=""
) as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)

View File

@ -16,7 +16,6 @@ import librosa
import pandas import pandas
import requests import requests
import soundfile # <= Has an external dependency on libsndfile import soundfile # <= Has an external dependency on libsndfile
from coqui_stt_training.util.importers import validate_label_eng as validate_label from coqui_stt_training.util.importers import validate_label_eng as validate_label
# ARCHIVE_NAME refers to ISIP alignments from 01/29/03 # ARCHIVE_NAME refers to ISIP alignments from 01/29/03
@ -315,8 +314,7 @@ def _split_sets(filelist):
def get_sample_size(population_size): def get_sample_size(population_size):
"""calculates the sample size for a 99% confidence and 1% margin of error """calculates the sample size for a 99% confidence and 1% margin of error"""
"""
margin_of_error = 0.01 margin_of_error = 0.01
fraction_picking = 0.50 fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99% z_score = 2.58 # Corresponds to confidence level 99%

View File

@ -21,10 +21,9 @@ from multiprocessing.pool import ThreadPool
import progressbar import progressbar
import sox import sox
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import validate_label_eng as validate_label from coqui_stt_training.util.importers import validate_label_eng as validate_label
from coqui_stt_ctcdecoder import Alphabet
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
SWC_ARCHIVE = "SWC_{language}.tar" SWC_ARCHIVE = "SWC_{language}.tar"
@ -173,7 +172,6 @@ def in_alphabet(alphabet, c):
return alphabet.CanEncode(c) if alphabet else True return alphabet.CanEncode(c) if alphabet else True
ALPHABETS = {} ALPHABETS = {}
@ -202,8 +200,16 @@ def label_filter(label, language):
dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else "" dont_normalize = DONT_NORMALIZE[language] if language in DONT_NORMALIZE else ""
alphabet = get_alphabet(language) alphabet = get_alphabet(language)
for c in label: for c in label:
if CLI_ARGS.normalize and c not in dont_normalize and not in_alphabet(alphabet, c): if (
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") CLI_ARGS.normalize
and c not in dont_normalize
and not in_alphabet(alphabet, c)
):
c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if not in_alphabet(alphabet, sc): if not in_alphabet(alphabet, sc):
return None, "illegal character" return None, "illegal character"

View File

@ -7,11 +7,11 @@ from glob import glob
from os import makedirs, path, remove, rmdir from os import makedirs, path, remove, rmdir
import pandas import pandas
from sox import Transformer
from tensorflow.python.platform import gfile
from coqui_stt_training.util.downloader import maybe_download from coqui_stt_training.util.downloader import maybe_download
from coqui_stt_training.util.stm import parse_stm_file from coqui_stt_training.util.stm import parse_stm_file
from sox import Transformer
from tensorflow.python.platform import gfile
def _download_and_preprocess_data(data_dir): def _download_and_preprocess_data(data_dir):

View File

@ -8,7 +8,6 @@ from multiprocessing import Pool
import progressbar import progressbar
import sox import sox
import unidecode import unidecode
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
@ -132,9 +131,15 @@ def _maybe_convert_sets(target_dir, extracted_data, english_compatible=False):
pool.close() pool.close()
pool.join() pool.join()
with open(target_csv_template.format("train"), "w", encoding="utf-8", newline="") as train_csv_file: # 80% with open(
with open(target_csv_template.format("dev"), "w", encoding="utf-8", newline="") as dev_csv_file: # 10% target_csv_template.format("train"), "w", encoding="utf-8", newline=""
with open(target_csv_template.format("test"), "w", encoding="utf-8", newline="") as test_csv_file: # 10% ) as train_csv_file: # 80%
with open(
target_csv_template.format("dev"), "w", encoding="utf-8", newline=""
) as dev_csv_file: # 10%
with open(
target_csv_template.format("test"), "w", encoding="utf-8", newline=""
) as test_csv_file: # 10%
train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES) train_writer = csv.DictWriter(train_csv_file, fieldnames=FIELDNAMES)
train_writer.writeheader() train_writer.writeheader()
dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES) dev_writer = csv.DictWriter(dev_csv_file, fieldnames=FIELDNAMES)

View File

@ -13,10 +13,9 @@ import xml.etree.ElementTree as ET
from collections import Counter from collections import Counter
import progressbar import progressbar
from coqui_stt_ctcdecoder import Alphabet
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import validate_label_eng as validate_label from coqui_stt_training.util.importers import validate_label_eng as validate_label
from coqui_stt_ctcdecoder import Alphabet
TUDA_VERSION = "v2" TUDA_VERSION = "v2"
TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION) TUDA_PACKAGE = "german-speechdata-package-{}".format(TUDA_VERSION)
@ -55,7 +54,11 @@ def check_and_prepare_sentence(sentence):
chars = [] chars = []
for c in sentence: for c in sentence:
if CLI_ARGS.normalize and c not in "äöüß" and not in_alphabet(c): if CLI_ARGS.normalize and c not in "äöüß" and not in_alphabet(c):
c = unicodedata.normalize("NFKD", c).encode("ascii", "ignore").decode("ascii", "ignore") c = (
unicodedata.normalize("NFKD", c)
.encode("ascii", "ignore")
.decode("ascii", "ignore")
)
for sc in c: for sc in c:
if not in_alphabet(c): if not in_alphabet(c):
return None return None
@ -118,7 +121,7 @@ def write_csvs(extracted):
sentence = list(meta.iter("cleaned_sentence"))[0].text sentence = list(meta.iter("cleaned_sentence"))[0].text
sentence = check_and_prepare_sentence(sentence) sentence = check_and_prepare_sentence(sentence)
if sentence is None: if sentence is None:
reasons['alphabet filter'] += 1 reasons["alphabet filter"] += 1
continue continue
for wav_name in wav_names: for wav_name in wav_names:
sample_counter += 1 sample_counter += 1

View File

@ -10,7 +10,6 @@ from zipfile import ZipFile
import librosa import librosa
import progressbar import progressbar
from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download from coqui_stt_training.util.downloader import SIMPLE_BAR, maybe_download
from coqui_stt_training.util.importers import ( from coqui_stt_training.util.importers import (
get_counter, get_counter,

View File

@ -13,9 +13,10 @@ from os import makedirs, path
import pandas import pandas
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from tensorflow.python.platform import gfile
from coqui_stt_training.util.downloader import maybe_download from coqui_stt_training.util.downloader import maybe_download
from tensorflow.python.platform import gfile
"""The number of jobs to run in parallel""" """The number of jobs to run in parallel"""
NUM_PARALLEL = 8 NUM_PARALLEL = 8

View File

@ -4,14 +4,26 @@ Tool for playing (and augmenting) single samples or samples from Sample Database
Use "python3 play.py -h" for help Use "python3 play.py -h" for help
""" """
import os
import sys
import random
import argparse import argparse
import os
import random
import sys
from coqui_stt_training.util.audio import get_loadable_audio_type_from_extension, AUDIO_TYPE_PCM, AUDIO_TYPE_WAV from coqui_stt_training.util.audio import (
from coqui_stt_training.util.sample_collections import SampleList, LabeledSample, samples_from_source AUDIO_TYPE_PCM,
from coqui_stt_training.util.augmentations import parse_augmentations, apply_sample_augmentations, SampleAugmentation AUDIO_TYPE_WAV,
get_loadable_audio_type_from_extension,
)
from coqui_stt_training.util.augmentations import (
SampleAugmentation,
apply_sample_augmentations,
parse_augmentations,
)
from coqui_stt_training.util.sample_collections import (
LabeledSample,
SampleList,
samples_from_source,
)
def get_samples_in_play_order(): def get_samples_in_play_order():
@ -43,11 +55,13 @@ def play_collection():
if any(not isinstance(a, SampleAugmentation) for a in augmentations): if any(not isinstance(a, SampleAugmentation) for a in augmentations):
print("Warning: Some of the augmentations cannot be simulated by this command.") print("Warning: Some of the augmentations cannot be simulated by this command.")
samples = get_samples_in_play_order() samples = get_samples_in_play_order()
samples = apply_sample_augmentations(samples, samples = apply_sample_augmentations(
audio_type=AUDIO_TYPE_PCM, samples,
augmentations=augmentations, audio_type=AUDIO_TYPE_PCM,
process_ahead=0, augmentations=augmentations,
clock=CLI_ARGS.clock) process_ahead=0,
clock=CLI_ARGS.clock,
)
for sample in samples: for sample in samples:
if not CLI_ARGS.quiet: if not CLI_ARGS.quiet:
print('Sample "{}"'.format(sample.sample_id), file=sys.stderr) print('Sample "{}"'.format(sample.sample_id), file=sys.stderr)
@ -57,10 +71,12 @@ def play_collection():
sample.change_audio_type(AUDIO_TYPE_WAV) sample.change_audio_type(AUDIO_TYPE_WAV)
sys.stdout.buffer.write(sample.audio.getvalue()) sys.stdout.buffer.write(sample.audio.getvalue())
return return
wave_obj = simpleaudio.WaveObject(sample.audio, wave_obj = simpleaudio.WaveObject(
sample.audio_format.channels, sample.audio,
sample.audio_format.width, sample.audio_format.channels,
sample.audio_format.rate) sample.audio_format.width,
sample.audio_format.rate,
)
play_obj = wave_obj.play() play_obj = wave_obj.play()
play_obj.wait_done() play_obj.wait_done()
@ -70,7 +86,9 @@ def handle_args():
description="Tool for playing (and augmenting) single samples or samples from Sample Databases (SDB files) " description="Tool for playing (and augmenting) single samples or samples from Sample Databases (SDB files) "
"and Coqui STT CSV files" "and Coqui STT CSV files"
) )
parser.add_argument("source", help="Sample DB, CSV or WAV file to play samples from") parser.add_argument(
"source", help="Sample DB, CSV or WAV file to play samples from"
)
parser.add_argument( parser.add_argument(
"--start", "--start",
type=int, type=int,
@ -90,7 +108,7 @@ def handle_args():
) )
parser.add_argument( parser.add_argument(
"--augment", "--augment",
action='append', action="append",
help="Add an augmentation operation", help="Add an augmentation operation",
) )
parser.add_argument( parser.add_argument(
@ -98,8 +116,8 @@ def handle_args():
type=float, type=float,
default=0.5, default=0.5,
help="Simulates clock value used for augmentations during training." help="Simulates clock value used for augmentations during training."
"Ranges from 0.0 (representing parameter start values) to" "Ranges from 0.0 (representing parameter start values) to"
"1.0 (representing parameter end values)", "1.0 (representing parameter end values)",
) )
parser.add_argument( parser.add_argument(
"--pipe", "--pipe",
@ -120,7 +138,9 @@ if __name__ == "__main__":
try: try:
import simpleaudio import simpleaudio
except ModuleNotFoundError: except ModuleNotFoundError:
print('Unless using the --pipe flag, play.py requires Python package "simpleaudio" for playing samples') print(
'Unless using the --pipe flag, play.py requires Python package "simpleaudio" for playing samples'
)
sys.exit(1) sys.exit(1)
try: try:
play_collection() play_collection()

View File

@ -8,4 +8,3 @@ This directory contains language-specific data files. Most importantly, you will
2. A script used to generate a binary n-gram language model: ``data/lm/generate_lm.py``. 2. A script used to generate a binary n-gram language model: ``data/lm/generate_lm.py``.
For more information on how to build these resources from scratch, see the ``External scorer scripts`` section on `stt.readthedocs.io <https://stt.readthedocs.io/>`_. For more information on how to build these resources from scratch, see the ``External scorer scripts`` section on `stt.readthedocs.io <https://stt.readthedocs.io/>`_.

View File

@ -78,20 +78,20 @@ def build_lm(args, data_lower, vocab_str):
print("\nCreating ARPA file ...") print("\nCreating ARPA file ...")
lm_path = os.path.join(args.output_dir, "lm.arpa") lm_path = os.path.join(args.output_dir, "lm.arpa")
subargs = [ subargs = [
os.path.join(args.kenlm_bins, "lmplz"), os.path.join(args.kenlm_bins, "lmplz"),
"--order", "--order",
str(args.arpa_order), str(args.arpa_order),
"--temp_prefix", "--temp_prefix",
args.output_dir, args.output_dir,
"--memory", "--memory",
args.max_arpa_memory, args.max_arpa_memory,
"--text", "--text",
data_lower, data_lower,
"--arpa", "--arpa",
lm_path, lm_path,
"--prune", "--prune",
*args.arpa_prune.split("|"), *args.arpa_prune.split("|"),
] ]
if args.discount_fallback: if args.discount_fallback:
subargs += ["--discount_fallback"] subargs += ["--discount_fallback"]
subprocess.check_call(subargs) subprocess.check_call(subargs)

View File

@ -22,21 +22,27 @@
import os import os
import sys import sys
sys.path.insert(0, os.path.abspath('../')) sys.path.insert(0, os.path.abspath("../"))
autodoc_mock_imports = ['stt'] autodoc_mock_imports = ["stt"]
# This is in fact only relevant on ReadTheDocs, but we want to run the same way # This is in fact only relevant on ReadTheDocs, but we want to run the same way
# on our CI as in RTD to avoid regressions on RTD that we would not catch on CI # on our CI as in RTD to avoid regressions on RTD that we would not catch on CI
import subprocess import subprocess
parent = subprocess.check_output("cd ../ && pwd", shell=True).decode().strip() parent = subprocess.check_output("cd ../ && pwd", shell=True).decode().strip()
os.environ["PATH"] = os.path.join(parent, 'node_modules', '.bin') + ':' + os.environ["PATH"] os.environ["PATH"] = (
subprocess.check_call('cd ../ && npm install typedoc@0.17.4 typescript@3.8.3 @types/node@13.9.x', shell=True) os.path.join(parent, "node_modules", ".bin") + ":" + os.environ["PATH"]
subprocess.check_call('env', shell=True) )
subprocess.check_call('which typedoc', shell=True) subprocess.check_call(
subprocess.check_call('cd ../ && doxygen doc/doxygen-c.conf', shell=True) "cd ../ && npm install typedoc@0.17.4 typescript@3.8.3 @types/node@13.9.x",
subprocess.check_call('cd ../ && doxygen doc/doxygen-java.conf', shell=True) shell=True,
subprocess.check_call('cd ../ && doxygen doc/doxygen-dotnet.conf', shell=True) )
subprocess.check_call("env", shell=True)
subprocess.check_call("which typedoc", shell=True)
subprocess.check_call("cd ../ && doxygen doc/doxygen-c.conf", shell=True)
subprocess.check_call("cd ../ && doxygen doc/doxygen-java.conf", shell=True)
subprocess.check_call("cd ../ && doxygen doc/doxygen-dotnet.conf", shell=True)
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
@ -44,11 +50,11 @@ import semver
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = u'Coqui STT' project = u"Coqui STT"
copyright = '2021 Coqui GmbH, 2020 DeepSpeech authors, 2019-2020 Mozilla Corporation' copyright = "2021 Coqui GmbH, 2020 DeepSpeech authors, 2019-2020 Mozilla Corporation"
author = 'Coqui GmbH' author = "Coqui GmbH"
with open('../VERSION', 'r') as ver: with open("../VERSION", "r") as ver:
v = ver.read().strip() v = ver.read().strip()
vv = semver.parse(v) vv = semver.parse(v)
@ -56,7 +62,7 @@ vv = semver.parse(v)
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
# built documents. # built documents.
# The short X.Y version # The short X.Y version
version = '{}.{}'.format(vv['major'], vv['minor']) version = "{}.{}".format(vv["major"], vv["minor"])
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = v release = v
@ -68,22 +74,22 @@ release = v
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = [ extensions = [
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.extlinks', "sphinx.ext.extlinks",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinx.ext.mathjax', "sphinx.ext.mathjax",
'sphinx.ext.viewcode', "sphinx.ext.viewcode",
'sphinx_js', "sphinx_js",
'sphinx_csharp', "sphinx_csharp",
'breathe', "breathe",
'recommonmark', "recommonmark",
] ]
breathe_projects = { breathe_projects = {
"stt-c": "xml-c/", "stt-c": "xml-c/",
"stt-java": "xml-java/", "stt-java": "xml-java/",
"stt-dotnet": "xml-dotnet/", "stt-dotnet": "xml-dotnet/",
} }
js_source_path = "../native_client/javascript/index.ts" js_source_path = "../native_client/javascript/index.ts"
@ -91,16 +97,16 @@ js_language = "typescript"
jsdoc_config_path = "../native_client/javascript/tsconfig.json" jsdoc_config_path = "../native_client/javascript/tsconfig.json"
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['.templates'] templates_path = [".templates"]
# The suffix(es) of source filenames. # The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string: # You can specify multiple suffix as a list of string:
# #
# source_suffix = ['.rst', '.md'] # source_suffix = ['.rst', '.md']
source_suffix = '.rst' source_suffix = ".rst"
# The main toctree document. # The main toctree document.
master_doc = 'index' master_doc = "index"
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
@ -112,10 +118,10 @@ language = None
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path # This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store', 'node_modules', 'examples'] exclude_patterns = [".build", "Thumbs.db", ".DS_Store", "node_modules", "examples"]
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing. # If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False todo_include_todos = False
@ -128,18 +134,18 @@ add_module_names = False
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
# #
html_theme = 'furo' html_theme = "furo"
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['.static'] html_static_path = [".static"]
# -- Options for HTMLHelp output ------------------------------------------ # -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'STTdoc' htmlhelp_basename = "STTdoc"
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------
@ -148,15 +154,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
# #
# 'papersize': 'letterpaper', # 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt'). # The font size ('10pt', '11pt' or '12pt').
# #
# 'pointsize': '10pt', # 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# #
# 'preamble': '', # 'preamble': '',
# Latex figure (float) alignment # Latex figure (float) alignment
# #
# 'figure_align': 'htbp', # 'figure_align': 'htbp',
@ -166,8 +169,7 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, 'STT.tex', u'Coqui STT Documentation', (master_doc, "STT.tex", u"Coqui STT Documentation", u"Coqui GmbH", "manual"),
u'Coqui GmbH', 'manual'),
] ]
@ -175,10 +177,7 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [(master_doc, "stt", u"Coqui STT Documentation", [author], 1)]
(master_doc, 'stt', u'Coqui STT Documentation',
[author], 1)
]
# -- Options for Texinfo output ------------------------------------------- # -- Options for Texinfo output -------------------------------------------
@ -187,16 +186,21 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
(master_doc, 'STT', u'Coqui STT Documentation', (
author, 'STT', 'One line description of project.', master_doc,
'Miscellaneous'), "STT",
u"Coqui STT Documentation",
author,
"STT",
"One line description of project.",
"Miscellaneous",
),
] ]
# Example configuration for intersphinx: refer to the Python standard library. # Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'https://docs.python.org/': None} intersphinx_mapping = {"https://docs.python.org/": None}
extlinks = {'github': ('https://github.com/coqui-ai/STT/blob/v{}/%s'.format(release), extlinks = {
'%s')} "github": ("https://github.com/coqui-ai/STT/blob/v{}/%s".format(release), "%s")
}

View File

@ -2,11 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
if __name__ == '__main__': if __name__ == "__main__":
try: try:
from coqui_stt_training import evaluate as ds_evaluate from coqui_stt_training import evaluate as ds_evaluate
except ImportError: except ImportError:
print('Training package is not installed. See training documentation.') print("Training package is not installed. See training documentation.")
raise raise
ds_evaluate.run_script() ds_evaluate.run_script()

View File

@ -2,22 +2,22 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import absl.app
import argparse import argparse
import numpy as np
import wave
import csv import csv
import os import os
import sys import sys
import wave
from functools import partial
from multiprocessing import JoinableQueue, Manager, Process, cpu_count
from stt import Model import absl.app
import numpy as np
from coqui_stt_training.util.evaluate_tools import calculate_and_print_report from coqui_stt_training.util.evaluate_tools import calculate_and_print_report
from coqui_stt_training.util.flags import create_flags from coqui_stt_training.util.flags import create_flags
from functools import partial from six.moves import range, zip
from multiprocessing import JoinableQueue, Process, cpu_count, Manager from stt import Model
from six.moves import zip, range
r''' r"""
This module should be self-contained: This module should be self-contained:
- build libstt.so with TFLite: - build libstt.so with TFLite:
- bazel build [...] --define=runtime=tflite [...] //native_client:libstt.so - bazel build [...] --define=runtime=tflite [...] //native_client:libstt.so
@ -27,10 +27,11 @@ This module should be self-contained:
- pip install -r requirements_eval_tflite.txt - pip install -r requirements_eval_tflite.txt
Then run with a TFLite model, a scorer and a CSV test file Then run with a TFLite model, a scorer and a CSV test file
''' """
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask): def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask) os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_mask)
ds = Model(model) ds = Model(model)
ds.enableExternalScorer(scorer) ds.enableExternalScorer(scorer)
@ -38,29 +39,41 @@ def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
try: try:
msg = queue_in.get() msg = queue_in.get()
filename = msg['filename'] filename = msg["filename"]
fin = wave.open(filename, 'rb') fin = wave.open(filename, "rb")
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
fin.close() fin.close()
decoded = ds.stt(audio) decoded = ds.stt(audio)
queue_out.put({'wav': filename, 'prediction': decoded, 'ground_truth': msg['transcript']}) queue_out.put(
{
"wav": filename,
"prediction": decoded,
"ground_truth": msg["transcript"],
}
)
except FileNotFoundError as ex: except FileNotFoundError as ex:
print('FileNotFoundError: ', ex) print("FileNotFoundError: ", ex)
print(queue_out.qsize(), end='\r') # Update the current progress print(queue_out.qsize(), end="\r") # Update the current progress
queue_in.task_done() queue_in.task_done()
def main(args, _): def main(args, _):
manager = Manager() manager = Manager()
work_todo = JoinableQueue() # this is where we are going to store input data work_todo = JoinableQueue() # this is where we are going to store input data
work_done = manager.Queue() # this where we are gonna push them out work_done = manager.Queue() # this where we are gonna push them out
processes = [] processes = []
for i in range(args.proc): for i in range(args.proc):
worker_process = Process(target=tflite_worker, args=(args.model, args.scorer, work_todo, work_done, i), daemon=True, name='tflite_process_{}'.format(i)) worker_process = Process(
worker_process.start() # Launch reader() as a separate python process target=tflite_worker,
args=(args.model, args.scorer, work_todo, work_done, i),
daemon=True,
name="tflite_process_{}".format(i),
)
worker_process.start() # Launch reader() as a separate python process
processes.append(worker_process) processes.append(worker_process)
print([x.name for x in processes]) print([x.name for x in processes])
@ -71,56 +84,75 @@ def main(args, _):
losses = [] losses = []
wav_filenames = [] wav_filenames = []
with open(args.csv, 'r') as csvfile: with open(args.csv, "r") as csvfile:
csvreader = csv.DictReader(csvfile) csvreader = csv.DictReader(csvfile)
count = 0 count = 0
for row in csvreader: for row in csvreader:
count += 1 count += 1
# Relative paths are relative to the folder the CSV file is in # Relative paths are relative to the folder the CSV file is in
if not os.path.isabs(row['wav_filename']): if not os.path.isabs(row["wav_filename"]):
row['wav_filename'] = os.path.join(os.path.dirname(args.csv), row['wav_filename']) row["wav_filename"] = os.path.join(
work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']}) os.path.dirname(args.csv), row["wav_filename"]
wav_filenames.extend(row['wav_filename']) )
work_todo.put(
{"filename": row["wav_filename"], "transcript": row["transcript"]}
)
wav_filenames.extend(row["wav_filename"])
print('Totally %d wav entries found in csv\n' % count) print("Totally %d wav entries found in csv\n" % count)
work_todo.join() work_todo.join()
print('\nTotally %d wav file transcripted' % work_done.qsize()) print("\nTotally %d wav file transcripted" % work_done.qsize())
while not work_done.empty(): while not work_done.empty():
msg = work_done.get() msg = work_done.get()
losses.append(0.0) losses.append(0.0)
ground_truths.append(msg['ground_truth']) ground_truths.append(msg["ground_truth"])
predictions.append(msg['prediction']) predictions.append(msg["prediction"])
wavlist.append(msg['wav']) wavlist.append(msg["wav"])
# Print test summary # Print test summary
_ = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, args.csv) _ = calculate_and_print_report(
wav_filenames, ground_truths, predictions, losses, args.csv
)
if args.dump: if args.dump:
with open(args.dump + '.txt', 'w') as ftxt, open(args.dump + '.out', 'w') as fout: with open(args.dump + ".txt", "w") as ftxt, open(
args.dump + ".out", "w"
) as fout:
for wav, txt, out in zip(wavlist, ground_truths, predictions): for wav, txt, out in zip(wavlist, ground_truths, predictions):
ftxt.write('%s %s\n' % (wav, txt)) ftxt.write("%s %s\n" % (wav, txt))
fout.write('%s %s\n' % (wav, out)) fout.write("%s %s\n" % (wav, out))
print('Reference texts dumped to %s.txt' % args.dump) print("Reference texts dumped to %s.txt" % args.dump)
print('Transcription dumped to %s.out' % args.dump) print("Transcription dumped to %s.out" % args.dump)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Computing TFLite accuracy') parser = argparse.ArgumentParser(description="Computing TFLite accuracy")
parser.add_argument('--model', required=True, parser.add_argument(
help='Path to the model (protocol buffer binary file)') "--model", required=True, help="Path to the model (protocol buffer binary file)"
parser.add_argument('--scorer', required=True, )
help='Path to the external scorer file') parser.add_argument(
parser.add_argument('--csv', required=True, "--scorer", required=True, help="Path to the external scorer file"
help='Path to the CSV source file') )
parser.add_argument('--proc', required=False, default=cpu_count(), type=int, parser.add_argument("--csv", required=True, help="Path to the CSV source file")
help='Number of processes to spawn, defaulting to number of CPUs') parser.add_argument(
parser.add_argument('--dump', required=False, "--proc",
help='Path to dump the results as text file, with one line for each wav: "wav transcription".') required=False,
default=cpu_count(),
type=int,
help="Number of processes to spawn, defaulting to number of CPUs",
)
parser.add_argument(
"--dump",
required=False,
help='Path to dump the results as text file, with one line for each wav: "wav transcription".',
)
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
# Reconstruct argv for absl.flags # Reconstruct argv for absl.flags
sys.argv = [sys.argv[0]] + unknown sys.argv = [sys.argv[0]] + unknown
return args return args
if __name__ == '__main__':
if __name__ == "__main__":
create_flags() create_flags()
absl.app.run(partial(main, parse_args())) absl.app.run(partial(main, parse_args()))

View File

@ -2,35 +2,39 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
import sys
import absl.app import absl.app
import optuna import optuna
import sys from coqui_stt_ctcdecoder import Scorer
import tensorflow.compat.v1 as tfv1
from coqui_stt_training.evaluate import evaluate from coqui_stt_training.evaluate import evaluate
from coqui_stt_training.train import create_model from coqui_stt_training.train import create_model
from coqui_stt_training.util.config import Config, initialize_globals from coqui_stt_training.util.config import Config, initialize_globals
from coqui_stt_training.util.flags import create_flags, FLAGS
from coqui_stt_training.util.logging import log_error
from coqui_stt_training.util.evaluate_tools import wer_cer_batch from coqui_stt_training.util.evaluate_tools import wer_cer_batch
from coqui_stt_ctcdecoder import Scorer from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import log_error
import tensorflow.compat.v1 as tfv1
def character_based(): def character_based():
is_character_based = False is_character_based = False
if FLAGS.scorer_path: if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) scorer = Scorer(
FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
)
is_character_based = scorer.is_utf8_mode() is_character_based = scorer.is_utf8_mode()
return is_character_based return is_character_based
def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)
is_character_based = trial.study.user_attrs['is_character_based'] def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max)
is_character_based = trial.study.user_attrs["is_character_based"]
samples = [] samples = []
for step, test_file in enumerate(FLAGS.test_files.split(',')): for step, test_file in enumerate(FLAGS.test_files.split(",")):
tfv1.reset_default_graph() tfv1.reset_default_graph()
current_samples = evaluate([test_file], create_model) current_samples = evaluate([test_file], create_model)
@ -47,12 +51,15 @@ def objective(trial):
wer, cer = wer_cer_batch(samples) wer, cer = wer_cer_batch(samples)
return cer if is_character_based else wer return cer if is_character_based else wer
def main(_): def main(_):
initialize_globals() initialize_globals()
if not FLAGS.test_files: if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via ' log_error(
'the --test_files flag.') "You need to specify what files to use for evaluation via "
"the --test_files flag."
)
sys.exit(1) sys.exit(1)
is_character_based = character_based() is_character_based = character_based()
@ -60,11 +67,15 @@ def main(_):
study = optuna.create_study() study = optuna.create_study()
study.set_user_attr("is_character_based", is_character_based) study.set_user_attr("is_character_based", is_character_based)
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'], print(
study.best_params['lm_beta'], "Best params: lm_alpha={} and lm_beta={} with WER={}".format(
study.best_value)) study.best_params["lm_alpha"],
study.best_params["lm_beta"],
study.best_value,
)
)
if __name__ == '__main__': if __name__ == "__main__":
create_flags() create_flags()
absl.app.run(main) absl.app.run(main)

View File

@ -20,4 +20,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.

View File

@ -1,17 +1,18 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self from . import swigwrapper # pylint: disable=import-self
# This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle # This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle
# string encoding explicitly, here and throughout this file. # string encoding explicitly, here and throughout this file.
__version__ = swigwrapper.__version__.decode('utf-8') __version__ = swigwrapper.__version__.decode("utf-8")
# Hack: import error codes by matching on their names, as SWIG unfortunately # Hack: import error codes by matching on their names, as SWIG unfortunately
# does not support binding enums to Python in a scoped manner yet. # does not support binding enums to Python in a scoped manner yet.
for symbol in dir(swigwrapper): for symbol in dir(swigwrapper):
if symbol.startswith('STT_ERR_'): if symbol.startswith("STT_ERR_"):
globals()[symbol] = getattr(swigwrapper, symbol) globals()[symbol] = getattr(swigwrapper, symbol)
class Scorer(swigwrapper.Scorer): class Scorer(swigwrapper.Scorer):
"""Wrapper for Scorer. """Wrapper for Scorer.
@ -23,130 +24,140 @@ class Scorer(swigwrapper.Scorer):
:alphabet: Alphabet :alphabet: Alphabet
:type scorer_path: basestring :type scorer_path: basestring
""" """
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None): def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
super(Scorer, self).__init__() super(Scorer, self).__init__()
# Allow bare initialization # Allow bare initialization
if alphabet: if alphabet:
assert alpha is not None, 'alpha parameter is required' assert alpha is not None, "alpha parameter is required"
assert beta is not None, 'beta parameter is required' assert beta is not None, "beta parameter is required"
assert scorer_path, 'scorer_path parameter is required' assert scorer_path, "scorer_path parameter is required"
err = self.init(scorer_path.encode('utf-8'), alphabet) err = self.init(scorer_path.encode("utf-8"), alphabet)
if err != 0: if err != 0:
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err)) raise ValueError(
"Scorer initialization failed with error code 0x{:X}".format(err)
)
self.reset_params(alpha, beta) self.reset_params(alpha, beta)
class Alphabet(swigwrapper.Alphabet): class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor""" """Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path): def __init__(self, config_path):
super(Alphabet, self).__init__() super(Alphabet, self).__init__()
err = self.init(config_path.encode('utf-8')) err = self.init(config_path.encode("utf-8"))
if err != 0: if err != 0:
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err)) raise ValueError(
"Alphabet initialization failed with error code 0x{:X}".format(err)
)
def CanEncodeSingle(self, input): def CanEncodeSingle(self, input):
''' """
Returns true if the single character/output class has a corresponding label Returns true if the single character/output class has a corresponding label
in the alphabet. in the alphabet.
''' """
return super(Alphabet, self).CanEncodeSingle(input.encode('utf-8')) return super(Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
def CanEncode(self, input): def CanEncode(self, input):
''' """
Returns true if the entire string can be encoded into labels in this Returns true if the entire string can be encoded into labels in this
alphabet. alphabet.
''' """
return super(Alphabet, self).CanEncode(input.encode('utf-8')) return super(Alphabet, self).CanEncode(input.encode("utf-8"))
def EncodeSingle(self, input): def EncodeSingle(self, input):
''' """
Encode a single character/output class into a label. Character must be in Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use `CanEncodeSingle` to test. the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
''' """
return super(Alphabet, self).EncodeSingle(input.encode('utf-8')) return super(Alphabet, self).EncodeSingle(input.encode("utf-8"))
def Encode(self, input): def Encode(self, input):
''' """
Encode a sequence of character/output classes into a sequence of labels. Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint. Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test. `CanEncode` and `CanEncodeSingle` to test.
''' """
# Convert SWIG's UnsignedIntVec to a Python list # Convert SWIG's UnsignedIntVec to a Python list
res = super(Alphabet, self).Encode(input.encode('utf-8')) res = super(Alphabet, self).Encode(input.encode("utf-8"))
return [el for el in res] return [el for el in res]
def DecodeSingle(self, input): def DecodeSingle(self, input):
res = super(Alphabet, self).DecodeSingle(input) res = super(Alphabet, self).DecodeSingle(input)
return res.decode('utf-8') return res.decode("utf-8")
def Decode(self, input): def Decode(self, input):
'''Decode a sequence of labels into a string.''' """Decode a sequence of labels into a string."""
res = super(Alphabet, self).Decode(input) res = super(Alphabet, self).Decode(input)
return res.decode('utf-8') return res.decode("utf-8")
class UTF8Alphabet(swigwrapper.UTF8Alphabet): class UTF8Alphabet(swigwrapper.UTF8Alphabet):
"""Convenience wrapper for UTF8Alphabet which calls init in the constructor""" """Convenience wrapper for UTF8Alphabet which calls init in the constructor"""
def __init__(self): def __init__(self):
super(UTF8Alphabet, self).__init__() super(UTF8Alphabet, self).__init__()
err = self.init(b'') err = self.init(b"")
if err != 0: if err != 0:
raise ValueError('UTF8Alphabet initialization failed with error code 0x{:X}'.format(err)) raise ValueError(
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
)
def CanEncodeSingle(self, input): def CanEncodeSingle(self, input):
''' """
Returns true if the single character/output class has a corresponding label Returns true if the single character/output class has a corresponding label
in the alphabet. in the alphabet.
''' """
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode('utf-8')) return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
def CanEncode(self, input): def CanEncode(self, input):
''' """
Returns true if the entire string can be encoded into labels in this Returns true if the entire string can be encoded into labels in this
alphabet. alphabet.
''' """
return super(UTF8Alphabet, self).CanEncode(input.encode('utf-8')) return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
def EncodeSingle(self, input): def EncodeSingle(self, input):
''' """
Encode a single character/output class into a label. Character must be in Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use `CanEncodeSingle` to test. the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
''' """
return super(UTF8Alphabet, self).EncodeSingle(input.encode('utf-8')) return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
def Encode(self, input): def Encode(self, input):
''' """
Encode a sequence of character/output classes into a sequence of labels. Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint. Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test. `CanEncode` and `CanEncodeSingle` to test.
''' """
# Convert SWIG's UnsignedIntVec to a Python list # Convert SWIG's UnsignedIntVec to a Python list
res = super(UTF8Alphabet, self).Encode(input.encode('utf-8')) res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
return [el for el in res] return [el for el in res]
def DecodeSingle(self, input): def DecodeSingle(self, input):
res = super(UTF8Alphabet, self).DecodeSingle(input) res = super(UTF8Alphabet, self).DecodeSingle(input)
return res.decode('utf-8') return res.decode("utf-8")
def Decode(self, input): def Decode(self, input):
'''Decode a sequence of labels into a string.''' """Decode a sequence of labels into a string."""
res = super(UTF8Alphabet, self).Decode(input) res = super(UTF8Alphabet, self).Decode(input)
return res.decode('utf-8') return res.decode("utf-8")
def ctc_beam_search_decoder(
def ctc_beam_search_decoder(probs_seq, probs_seq,
alphabet, alphabet,
beam_size, beam_size,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
scorer=None, scorer=None,
hot_words=dict(), hot_words=dict(),
num_results=1): num_results=1,
):
"""Wrapper for the CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
@ -175,22 +186,33 @@ def ctc_beam_search_decoder(probs_seq,
:rtype: list :rtype: list
""" """
beam_results = swigwrapper.ctc_beam_search_decoder( beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n, probs_seq,
scorer, hot_words, num_results) alphabet,
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] beam_size,
cutoff_prob,
cutoff_top_n,
scorer,
hot_words,
num_results,
)
beam_results = [
(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results
]
return beam_results return beam_results
def ctc_beam_search_decoder_batch(probs_seq, def ctc_beam_search_decoder_batch(
seq_lengths, probs_seq,
alphabet, seq_lengths,
beam_size, alphabet,
num_processes, beam_size,
cutoff_prob=1.0, num_processes,
cutoff_top_n=40, cutoff_prob=1.0,
scorer=None, cutoff_top_n=40,
hot_words=dict(), scorer=None,
num_results=1): hot_words=dict(),
num_results=1,
):
"""Wrapper for the batched CTC beam search decoder. """Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list :param probs_seq: 3-D list with each element as an instance of 2-D list
@ -222,7 +244,18 @@ def ctc_beam_search_decoder_batch(probs_seq,
results, in descending order of the confidence. results, in descending order of the confidence.
:rtype: list :rtype: list
""" """
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results) batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(
probs_seq,
seq_lengths,
alphabet,
beam_size,
num_processes,
cutoff_prob,
cutoff_top_n,
scorer,
hot_words,
num_results,
)
batch_beam_results = [ batch_beam_results = [
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results for beam_results in batch_beam_results

View File

@ -6,84 +6,95 @@ import os
import shlex import shlex
import subprocess import subprocess
import sys import sys
from multiprocessing.dummy import Pool from multiprocessing.dummy import Pool
if sys.platform.startswith('win'): if sys.platform.startswith("win"):
ARGS = ['/nologo', '/D KENLM_MAX_ORDER=6', '/EHsc', '/source-charset:utf-8'] ARGS = ["/nologo", "/D KENLM_MAX_ORDER=6", "/EHsc", "/source-charset:utf-8"]
OPT_ARGS = ['/O2', '/MT', '/D NDEBUG'] OPT_ARGS = ["/O2", "/MT", "/D NDEBUG"]
DBG_ARGS = ['/Od', '/MTd', '/Zi', '/U NDEBUG', '/D DEBUG'] DBG_ARGS = ["/Od", "/MTd", "/Zi", "/U NDEBUG", "/D DEBUG"]
OPENFST_DIR = 'third_party/openfst-1.6.9-win' OPENFST_DIR = "third_party/openfst-1.6.9-win"
else: else:
ARGS = ['-fPIC', '-DKENLM_MAX_ORDER=6', '-std=c++11', '-Wno-unused-local-typedefs', '-Wno-sign-compare'] ARGS = [
OPT_ARGS = ['-O3', '-DNDEBUG'] "-fPIC",
DBG_ARGS = ['-O0', '-g', '-UNDEBUG', '-DDEBUG'] "-DKENLM_MAX_ORDER=6",
OPENFST_DIR = 'third_party/openfst-1.6.7' "-std=c++11",
"-Wno-unused-local-typedefs",
"-Wno-sign-compare",
]
OPT_ARGS = ["-O3", "-DNDEBUG"]
DBG_ARGS = ["-O0", "-g", "-UNDEBUG", "-DDEBUG"]
OPENFST_DIR = "third_party/openfst-1.6.7"
INCLUDES = [ INCLUDES = [
'..', "..",
'../kenlm', "../kenlm",
OPENFST_DIR + '/src/include', OPENFST_DIR + "/src/include",
'third_party/ThreadPool', "third_party/ThreadPool",
'third_party/object_pool' "third_party/object_pool",
] ]
KENLM_FILES = (glob.glob('../kenlm/util/*.cc') KENLM_FILES = (
+ glob.glob('../kenlm/lm/*.cc') glob.glob("../kenlm/util/*.cc")
+ glob.glob('../kenlm/util/double-conversion/*.cc')) + glob.glob("../kenlm/lm/*.cc")
+ glob.glob("../kenlm/util/double-conversion/*.cc")
)
KENLM_FILES += glob.glob(OPENFST_DIR + '/src/lib/*.cc') KENLM_FILES += glob.glob(OPENFST_DIR + "/src/lib/*.cc")
KENLM_FILES = [ KENLM_FILES = [
fn for fn in KENLM_FILES fn
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( for fn in KENLM_FILES
'unittest.cc')) if not (
fn.endswith("main.cc") or fn.endswith("test.cc") or fn.endswith("unittest.cc")
)
] ]
CTC_DECODER_FILES = [ CTC_DECODER_FILES = [
'ctc_beam_search_decoder.cpp', "ctc_beam_search_decoder.cpp",
'scorer.cpp', "scorer.cpp",
'path_trie.cpp', "path_trie.cpp",
'decoder_utils.cpp', "decoder_utils.cpp",
'workspace_status.cc', "workspace_status.cc",
'../alphabet.cc', "../alphabet.cc",
] ]
def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
compiler = os.environ.get('CXX', 'g++') def build_archive(
if sys.platform.startswith('win'): srcs=[], out_name="", build_dir="temp_build/temp_build", debug=False, num_parallel=1
):
compiler = os.environ.get("CXX", "g++")
if sys.platform.startswith("win"):
compiler = '"{}"'.format(compiler) compiler = '"{}"'.format(compiler)
ar = os.environ.get('AR', 'ar') ar = os.environ.get("AR", "ar")
libexe = os.environ.get('LIBEXE', 'lib.exe') libexe = os.environ.get("LIBEXE", "lib.exe")
libtool = os.environ.get('LIBTOOL', 'libtool') libtool = os.environ.get("LIBTOOL", "libtool")
cflags = os.environ.get('CFLAGS', '') + os.environ.get('CXXFLAGS', '') cflags = os.environ.get("CFLAGS", "") + os.environ.get("CXXFLAGS", "")
args = ARGS + (DBG_ARGS if debug else OPT_ARGS) args = ARGS + (DBG_ARGS if debug else OPT_ARGS)
for file in srcs: for file in srcs:
outfile = os.path.join(build_dir, os.path.splitext(file)[0] + '.o') outfile = os.path.join(build_dir, os.path.splitext(file)[0] + ".o")
outdir = os.path.dirname(outfile) outdir = os.path.dirname(outfile)
if not os.path.exists(outdir): if not os.path.exists(outdir):
print('mkdir', outdir) print("mkdir", outdir)
os.makedirs(outdir) os.makedirs(outdir)
def build_one(file): def build_one(file):
outfile = os.path.join(build_dir, os.path.splitext(file)[0] + '.o') outfile = os.path.join(build_dir, os.path.splitext(file)[0] + ".o")
if os.path.exists(outfile): if os.path.exists(outfile):
return return
if sys.platform.startswith('win'): if sys.platform.startswith("win"):
file = '"{}"'.format(file.replace('\\', '/')) file = '"{}"'.format(file.replace("\\", "/"))
output = '/Fo"{}"'.format(outfile.replace('\\', '/')) output = '/Fo"{}"'.format(outfile.replace("\\", "/"))
else: else:
output = '-o ' + outfile output = "-o " + outfile
cmd = '{cc} -c {cflags} {args} {includes} {infile} {output}'.format( cmd = "{cc} -c {cflags} {args} {includes} {infile} {output}".format(
cc=compiler, cc=compiler,
cflags=cflags, cflags=cflags,
args=' '.join(args), args=" ".join(args),
includes=' '.join('-I' + i for i in INCLUDES), includes=" ".join("-I" + i for i in INCLUDES),
infile=file, infile=file,
output=output, output=output,
) )
@ -94,30 +105,28 @@ def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug
pool = Pool(num_parallel) pool = Pool(num_parallel)
obj_files = list(pool.imap_unordered(build_one, srcs)) obj_files = list(pool.imap_unordered(build_one, srcs))
if sys.platform.startswith('darwin'): if sys.platform.startswith("darwin"):
cmd = '{libtool} -static -o {outfile} {infiles}'.format( cmd = "{libtool} -static -o {outfile} {infiles}".format(
libtool=libtool, libtool=libtool,
outfile=out_name, outfile=out_name,
infiles=' '.join(obj_files), infiles=" ".join(obj_files),
) )
print(cmd) print(cmd)
subprocess.check_call(shlex.split(cmd)) subprocess.check_call(shlex.split(cmd))
elif sys.platform.startswith('win'): elif sys.platform.startswith("win"):
cmd = '"{libexe}" /OUT:"{outfile}" {infiles} /MACHINE:X64 /NOLOGO'.format( cmd = '"{libexe}" /OUT:"{outfile}" {infiles} /MACHINE:X64 /NOLOGO'.format(
libexe=libexe, libexe=libexe, outfile=out_name, infiles=" ".join(obj_files)
outfile=out_name, )
infiles=' '.join(obj_files)) cmd = cmd.replace("\\", "/")
cmd = cmd.replace('\\', '/')
print(cmd) print(cmd)
subprocess.check_call(shlex.split(cmd)) subprocess.check_call(shlex.split(cmd))
else: else:
cmd = '{ar} rcs {outfile} {infiles}'.format( cmd = "{ar} rcs {outfile} {infiles}".format(
ar=ar, ar=ar, outfile=out_name, infiles=" ".join(obj_files)
outfile=out_name,
infiles=' '.join(obj_files)
) )
print(cmd) print(cmd)
subprocess.check_call(shlex.split(cmd)) subprocess.check_call(shlex.split(cmd))
if __name__ == '__main__':
if __name__ == "__main__":
build_common() build_common()

View File

@ -13,4 +13,3 @@ bdist-dir=temp_build/temp_build
[install_lib] [install_lib]
build-dir=temp_build/temp_build build-dir=temp_build/temp_build

View File

@ -1,95 +1,105 @@
#!/usr/bin/env python #!/usr/bin/env python
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from distutils.command.build import build
from setuptools import setup, Extension, distutils
import argparse import argparse
import multiprocessing.pool import multiprocessing.pool
import os import os
import platform import platform
import sys import sys
from distutils.command.build import build
from build_archive import * from build_archive import *
from setuptools import Extension, distutils, setup
try: try:
import numpy import numpy
try: try:
numpy_include = numpy.get_include() numpy_include = numpy.get_include()
except AttributeError: except AttributeError:
numpy_include = numpy.get_numpy_include() numpy_include = numpy.get_numpy_include()
except ImportError: except ImportError:
numpy_include = '' numpy_include = ""
assert 'NUMPY_INCLUDE' in os.environ assert "NUMPY_INCLUDE" in os.environ
numpy_include = os.getenv('NUMPY_INCLUDE', numpy_include) numpy_include = os.getenv("NUMPY_INCLUDE", numpy_include)
numpy_min_ver = os.getenv('NUMPY_DEP_VERSION', '') numpy_min_ver = os.getenv("NUMPY_DEP_VERSION", "")
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
"--num_processes", "--num_processes",
default=1, default=1,
type=int, type=int,
help="Number of cpu processes to build package. (default: %(default)d)") help="Number of cpu processes to build package. (default: %(default)d)",
)
known_args, unknown_args = parser.parse_known_args() known_args, unknown_args = parser.parse_known_args()
debug = '--debug' in unknown_args debug = "--debug" in unknown_args
# reconstruct sys.argv to pass to setup below # reconstruct sys.argv to pass to setup below
sys.argv = [sys.argv[0]] + unknown_args sys.argv = [sys.argv[0]] + unknown_args
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
def maybe_rebuild(srcs, out_name, build_dir): def maybe_rebuild(srcs, out_name, build_dir):
if not os.path.exists(out_name): if not os.path.exists(out_name):
if not os.path.exists(build_dir): if not os.path.exists(build_dir):
os.makedirs(build_dir) os.makedirs(build_dir)
build_archive(srcs=srcs, build_archive(
out_name=out_name, srcs=srcs,
build_dir=build_dir, out_name=out_name,
num_parallel=known_args.num_processes, build_dir=build_dir,
debug=debug) num_parallel=known_args.num_processes,
debug=debug,
)
project_version = read('../../training/coqui_stt_training/VERSION').strip()
build_dir = 'temp_build/temp_build' project_version = read("../../training/coqui_stt_training/VERSION").strip()
if sys.platform.startswith('win'): build_dir = "temp_build/temp_build"
archive_ext = 'lib'
if sys.platform.startswith("win"):
archive_ext = "lib"
else: else:
archive_ext = 'a' archive_ext = "a"
third_party_build = 'third_party.{}'.format(archive_ext) third_party_build = "third_party.{}".format(archive_ext)
ctc_decoder_build = 'first_party.{}'.format(archive_ext) ctc_decoder_build = "first_party.{}".format(archive_ext)
maybe_rebuild(KENLM_FILES, third_party_build, build_dir) maybe_rebuild(KENLM_FILES, third_party_build, build_dir)
maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir) maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir)
decoder_module = Extension( decoder_module = Extension(
name='coqui_stt_ctcdecoder._swigwrapper', name="coqui_stt_ctcdecoder._swigwrapper",
sources=['swigwrapper.i'], sources=["swigwrapper.i"],
swig_opts=['-c++', '-extranative'], swig_opts=["-c++", "-extranative"],
language='c++', language="c++",
include_dirs=INCLUDES + [numpy_include], include_dirs=INCLUDES + [numpy_include],
extra_compile_args=ARGS + (DBG_ARGS if debug else OPT_ARGS), extra_compile_args=ARGS + (DBG_ARGS if debug else OPT_ARGS),
extra_link_args=[ctc_decoder_build, third_party_build], extra_link_args=[ctc_decoder_build, third_party_build],
) )
class BuildExtFirst(build): class BuildExtFirst(build):
sub_commands = [('build_ext', build.has_ext_modules), sub_commands = [
('build_py', build.has_pure_modules), ("build_ext", build.has_ext_modules),
('build_clib', build.has_c_libraries), ("build_py", build.has_pure_modules),
('build_scripts', build.has_scripts)] ("build_clib", build.has_c_libraries),
("build_scripts", build.has_scripts),
]
setup( setup(
name='coqui_stt_ctcdecoder', name="coqui_stt_ctcdecoder",
version=project_version, version=project_version,
description="""DS CTC decoder""", description="""DS CTC decoder""",
cmdclass = {'build': BuildExtFirst}, cmdclass={"build": BuildExtFirst},
ext_modules=[decoder_module], ext_modules=[decoder_module],
package_dir = {'coqui_stt_ctcdecoder': '.'}, package_dir={"coqui_stt_ctcdecoder": "."},
py_modules=['coqui_stt_ctcdecoder', 'coqui_stt_ctcdecoder.swigwrapper'], py_modules=["coqui_stt_ctcdecoder", "coqui_stt_ctcdecoder.swigwrapper"],
install_requires = ['numpy%s' % numpy_min_ver], install_requires=["numpy%s" % numpy_min_ver],
) )

View File

@ -11,5 +11,3 @@ org.gradle.jvmargs=-Xmx1536m
# This option should only be used with decoupled projects. More details, visit # This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true # org.gradle.parallel=true

View File

@ -70,4 +70,3 @@ public enum STT_Error_Codes {
private static int next = 0; private static int next = 0;
} }
} }

View File

@ -1,46 +1,44 @@
{ {
"targets": [ "targets": [
{
"target_name": "stt",
"sources": [ "stt_wrap.cxx" ],
"libraries": [
"$(LIBS)"
],
"include_dirs": [
"../"
],
"conditions": [
[ "OS=='mac'", {
"xcode_settings": {
"OTHER_CXXFLAGS": [
"-stdlib=libc++",
"-mmacosx-version-min=10.10"
],
"OTHER_LDFLAGS": [
"-stdlib=libc++",
"-mmacosx-version-min=10.10"
]
}
}
]
]
},
{
"target_name": "action_after_build",
"type": "none",
"dependencies": [ "<(module_name)" ],
"copies": [
{ {
"files": [ "<(PRODUCT_DIR)/<(module_name).node" ], "target_name": "stt",
"destination": "<(module_path)" "sources": ["stt_wrap.cxx"],
} "libraries": ["$(LIBS)"],
] "include_dirs": ["../"],
} "conditions": [
], [
"variables": { "OS=='mac'",
"build_v8_with_gn": 0, {
"v8_enable_pointer_compression": 0, "xcode_settings": {
"v8_enable_31bit_smis_on_64bit_arch": 0, "OTHER_CXXFLAGS": [
"enable_lto": 1 "-stdlib=libc++",
}, "-mmacosx-version-min=10.10",
],
"OTHER_LDFLAGS": [
"-stdlib=libc++",
"-mmacosx-version-min=10.10",
],
}
},
]
],
},
{
"target_name": "action_after_build",
"type": "none",
"dependencies": ["<(module_name)"],
"copies": [
{
"files": ["<(PRODUCT_DIR)/<(module_name).node"],
"destination": "<(module_path)",
}
],
},
],
"variables": {
"build_v8_with_gn": 0,
"v8_enable_pointer_compression": 0,
"v8_enable_31bit_smis_on_64bit_arch": 0,
"enable_lto": 1,
},
} }

View File

@ -1,27 +1,28 @@
import os import os
import platform import platform
#The API is not snake case which triggers linter errors # The API is not snake case which triggers linter errors
#pylint: disable=invalid-name # pylint: disable=invalid-name
if platform.system().lower() == "windows": if platform.system().lower() == "windows":
dslib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'lib') dslib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "lib")
# On Windows, we can't rely on RPATH being set to $ORIGIN/lib/ or on # On Windows, we can't rely on RPATH being set to $ORIGIN/lib/ or on
# @loader_path/lib # @loader_path/lib
if hasattr(os, 'add_dll_directory'): if hasattr(os, "add_dll_directory"):
# Starting with Python 3.8 this properly handles the problem # Starting with Python 3.8 this properly handles the problem
os.add_dll_directory(dslib_path) os.add_dll_directory(dslib_path)
else: else:
# Before Python 3.8 we need to change the PATH to include the proper # Before Python 3.8 we need to change the PATH to include the proper
# directory for the dynamic linker # directory for the dynamic linker
os.environ['PATH'] = dslib_path + ';' + os.environ['PATH'] os.environ["PATH"] = dslib_path + ";" + os.environ["PATH"]
import stt import stt
# rename for backwards compatibility # rename for backwards compatibility
from stt.impl import Version as version from stt.impl import Version as version
class Model(object): class Model(object):
""" """
Class holding a Coqui STT model Class holding a Coqui STT model
@ -29,13 +30,18 @@ class Model(object):
:param model_path: Path to model file to load :param model_path: Path to model file to load
:type model_path: str :type model_path: str
""" """
def __init__(self, model_path): def __init__(self, model_path):
# make sure the attribute is there if CreateModel fails # make sure the attribute is there if CreateModel fails
self._impl = None self._impl = None
status, impl = stt.impl.CreateModel(model_path) status, impl = stt.impl.CreateModel(model_path)
if status != 0: if status != 0:
raise RuntimeError("CreateModel failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) raise RuntimeError(
"CreateModel failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
self._impl = impl self._impl = impl
def __del__(self): def __del__(self):
@ -85,7 +91,11 @@ class Model(object):
""" """
status = stt.impl.EnableExternalScorer(self._impl, scorer_path) status = stt.impl.EnableExternalScorer(self._impl, scorer_path)
if status != 0: if status != 0:
raise RuntimeError("EnableExternalScorer failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) raise RuntimeError(
"EnableExternalScorer failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
def disableExternalScorer(self): def disableExternalScorer(self):
""" """
@ -111,7 +121,11 @@ class Model(object):
""" """
status = stt.impl.AddHotWord(self._impl, word, boost) status = stt.impl.AddHotWord(self._impl, word, boost)
if status != 0: if status != 0:
raise RuntimeError("AddHotWord failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) raise RuntimeError(
"AddHotWord failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
def eraseHotWord(self, word): def eraseHotWord(self, word):
""" """
@ -124,7 +138,11 @@ class Model(object):
""" """
status = stt.impl.EraseHotWord(self._impl, word) status = stt.impl.EraseHotWord(self._impl, word)
if status != 0: if status != 0:
raise RuntimeError("EraseHotWord failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) raise RuntimeError(
"EraseHotWord failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
def clearHotWords(self): def clearHotWords(self):
""" """
@ -134,7 +152,11 @@ class Model(object):
""" """
status = stt.impl.ClearHotWords(self._impl) status = stt.impl.ClearHotWords(self._impl)
if status != 0: if status != 0:
raise RuntimeError("ClearHotWords failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) raise RuntimeError(
"ClearHotWords failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
def setScorerAlphaBeta(self, alpha, beta): def setScorerAlphaBeta(self, alpha, beta):
""" """
@ -190,7 +212,11 @@ class Model(object):
""" """
status, ctx = stt.impl.CreateStream(self._impl) status, ctx = stt.impl.CreateStream(self._impl)
if status != 0: if status != 0:
raise RuntimeError("CreateStream failed with '{}' (0x{:X})".format(stt.impl.ErrorCodeToErrorMessage(status),status)) raise RuntimeError(
"CreateStream failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
return Stream(ctx) return Stream(ctx)
@ -199,6 +225,7 @@ class Stream(object):
Class wrapping a stt stream. The constructor cannot be called directly. Class wrapping a stt stream. The constructor cannot be called directly.
Use :func:`Model.createStream()` Use :func:`Model.createStream()`
""" """
def __init__(self, native_stream): def __init__(self, native_stream):
self._impl = native_stream self._impl = native_stream
@ -216,7 +243,9 @@ class Stream(object):
:throws: RuntimeError if the stream object is not valid :throws: RuntimeError if the stream object is not valid
""" """
if not self._impl: if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to feed an already finished stream?") raise RuntimeError(
"Stream object is not valid. Trying to feed an already finished stream?"
)
stt.impl.FeedAudioContent(self._impl, audio_buffer) stt.impl.FeedAudioContent(self._impl, audio_buffer)
def intermediateDecode(self): def intermediateDecode(self):
@ -229,7 +258,9 @@ class Stream(object):
:throws: RuntimeError if the stream object is not valid :throws: RuntimeError if the stream object is not valid
""" """
if not self._impl: if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") raise RuntimeError(
"Stream object is not valid. Trying to decode an already finished stream?"
)
return stt.impl.IntermediateDecode(self._impl) return stt.impl.IntermediateDecode(self._impl)
def intermediateDecodeWithMetadata(self, num_results=1): def intermediateDecodeWithMetadata(self, num_results=1):
@ -245,7 +276,9 @@ class Stream(object):
:throws: RuntimeError if the stream object is not valid :throws: RuntimeError if the stream object is not valid
""" """
if not self._impl: if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") raise RuntimeError(
"Stream object is not valid. Trying to decode an already finished stream?"
)
return stt.impl.IntermediateDecodeWithMetadata(self._impl, num_results) return stt.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
def finishStream(self): def finishStream(self):
@ -260,7 +293,9 @@ class Stream(object):
:throws: RuntimeError if the stream object is not valid :throws: RuntimeError if the stream object is not valid
""" """
if not self._impl: if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") raise RuntimeError(
"Stream object is not valid. Trying to finish an already finished stream?"
)
result = stt.impl.FinishStream(self._impl) result = stt.impl.FinishStream(self._impl)
self._impl = None self._impl = None
return result return result
@ -281,7 +316,9 @@ class Stream(object):
:throws: RuntimeError if the stream object is not valid :throws: RuntimeError if the stream object is not valid
""" """
if not self._impl: if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") raise RuntimeError(
"Stream object is not valid. Trying to finish an already finished stream?"
)
result = stt.impl.FinishStreamWithMetadata(self._impl, num_results) result = stt.impl.FinishStreamWithMetadata(self._impl, num_results)
self._impl = None self._impl = None
return result return result
@ -294,7 +331,9 @@ class Stream(object):
:throws: RuntimeError if the stream object is not valid :throws: RuntimeError if the stream object is not valid
""" """
if not self._impl: if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to free an already finished stream?") raise RuntimeError(
"Stream object is not valid. Trying to free an already finished stream?"
)
stt.impl.FreeStream(self._impl) stt.impl.FreeStream(self._impl)
self._impl = None self._impl = None
@ -311,13 +350,11 @@ class TokenMetadata(object):
The text for this token The text for this token
""" """
def timestep(self): def timestep(self):
""" """
Position of the token in units of 20ms Position of the token in units of 20ms
""" """
def start_time(self): def start_time(self):
""" """
Position of the token in seconds Position of the token in seconds
@ -328,6 +365,7 @@ class CandidateTranscript(object):
""" """
Stores the entire CTC output as an array of character metadata objects Stores the entire CTC output as an array of character metadata objects
""" """
def tokens(self): def tokens(self):
""" """
List of tokens List of tokens
@ -336,7 +374,6 @@ class CandidateTranscript(object):
:type: list :type: list
""" """
def confidence(self): def confidence(self):
""" """
Approximated confidence value for this transcription. This is roughly the Approximated confidence value for this transcription. This is roughly the

View File

@ -3,16 +3,16 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import argparse import argparse
import numpy as np import json
import shlex import shlex
import subprocess import subprocess
import sys import sys
import wave import wave
import json
from stt import Model, version
from timeit import default_timer as timer from timeit import default_timer as timer
import numpy as np
from stt import Model, version
try: try:
# `shlex` (not `shhlex`) provides quote() on Python 3; the misspelling made this
# import fail unconditionally and always take the ImportError fallback.
from shlex import quote
except ImportError: except ImportError:
@ -20,19 +20,26 @@ except ImportError:
def convert_samplerate(audio_path, desired_sample_rate):
    """
    Resample an audio file with SoX and return it as raw PCM samples.

    SoX is asked to write headerless, mono, 16-bit signed little-endian audio
    at ``desired_sample_rate`` to its standard output.

    :param audio_path: Path of the audio file to convert.
    :param desired_sample_rate: Target sample rate passed to SoX's ``--rate``.

    :return: Tuple ``(desired_sample_rate, samples)`` where ``samples`` is a
             ``numpy.int16`` array built from SoX's output.

    :throws: RuntimeError if SoX exits with a non-zero status.
    :throws: OSError if the SoX executable cannot be launched.
    """
    # Hand subprocess an argument list directly: no shell-style quoting of the
    # path is needed and nothing has to be re-split afterwards.
    sox_cmd = [
        "sox",
        audio_path,
        "--type",
        "raw",
        "--bits",
        "16",
        "--channels",
        "1",
        "--rate",
        str(desired_sample_rate),
        "--encoding",
        "signed-integer",
        "--endian",
        "little",
        "--compression",
        "0.0",
        "--no-dither",
        "-",  # write the converted audio to stdout
    ]

    try:
        output = subprocess.check_output(sox_cmd, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        raise RuntimeError("SoX returned non-zero status: {}".format(e.stderr))
    except OSError as e:
        raise OSError(
            e.errno,
            "SoX not found, use {}hz files or install it: {}".format(
                desired_sample_rate, e.strerror
            ),
        )

    return desired_sample_rate, np.frombuffer(output, np.int16)
def metadata_to_string(metadata):
    """Return the text of every token in ``metadata.tokens`` joined together."""
    pieces = [token.text for token in metadata.tokens]
    return "".join(pieces)
def words_from_candidate_transcript(metadata): def words_from_candidate_transcript(metadata):
@ -70,56 +77,78 @@ def words_from_candidate_transcript(metadata):
def metadata_json_output(metadata):
    """
    Serialize the candidate transcripts of ``metadata`` as a JSON document.

    The result has a single ``"transcripts"`` key holding one entry per
    candidate, each with its ``confidence`` and the word list produced by
    ``words_from_candidate_transcript``. Output is indented by two spaces.
    """
    transcripts = []
    for candidate in metadata.transcripts:
        transcripts.append(
            {
                "confidence": candidate.confidence,
                "words": words_from_candidate_transcript(candidate),
            }
        )
    return json.dumps({"transcripts": transcripts}, indent=2)
class VersionAction(argparse.Action):
    """
    argparse action that prints the library version and terminates.

    The action consumes no command-line values (``nargs=0``).
    """

    def __init__(self, *args, **kwargs):
        super(VersionAction, self).__init__(nargs=0, *args, **kwargs)

    def __call__(self, *args, **kwargs):
        print("Coqui STT ", version())
        # sys.exit() is always available; the bare exit() helper is injected by
        # the `site` module and is missing under `python -S` or frozen builds.
        sys.exit(0)
def main(): def main():
parser = argparse.ArgumentParser(description='Running Coqui STT inference.') parser = argparse.ArgumentParser(description="Running Coqui STT inference.")
parser.add_argument('--model', required=True, parser.add_argument(
help='Path to the model (protocol buffer binary file)') "--model", required=True, help="Path to the model (protocol buffer binary file)"
parser.add_argument('--scorer', required=False, )
help='Path to the external scorer file') parser.add_argument(
parser.add_argument('--audio', required=True, "--scorer", required=False, help="Path to the external scorer file"
help='Path to the audio file to run (WAV format)') )
parser.add_argument('--beam_width', type=int, parser.add_argument(
help='Beam width for the CTC decoder') "--audio", required=True, help="Path to the audio file to run (WAV format)"
parser.add_argument('--lm_alpha', type=float, )
help='Language model weight (lm_alpha). If not specified, use default from the scorer package.') parser.add_argument("--beam_width", type=int, help="Beam width for the CTC decoder")
parser.add_argument('--lm_beta', type=float, parser.add_argument(
help='Word insertion bonus (lm_beta). If not specified, use default from the scorer package.') "--lm_alpha",
parser.add_argument('--version', action=VersionAction, type=float,
help='Print version and exits') help="Language model weight (lm_alpha). If not specified, use default from the scorer package.",
parser.add_argument('--extended', required=False, action='store_true', )
help='Output string from extended metadata') parser.add_argument(
parser.add_argument('--json', required=False, action='store_true', "--lm_beta",
help='Output json from metadata with timestamp of each word') type=float,
parser.add_argument('--candidate_transcripts', type=int, default=3, help="Word insertion bonus (lm_beta). If not specified, use default from the scorer package.",
help='Number of candidate transcripts to include in JSON output') )
parser.add_argument('--hot_words', type=str, parser.add_argument(
help='Hot-words and their boosts.') "--version", action=VersionAction, help="Print version and exits"
)
parser.add_argument(
"--extended",
required=False,
action="store_true",
help="Output string from extended metadata",
)
parser.add_argument(
"--json",
required=False,
action="store_true",
help="Output json from metadata with timestamp of each word",
)
parser.add_argument(
"--candidate_transcripts",
type=int,
default=3,
help="Number of candidate transcripts to include in JSON output",
)
parser.add_argument("--hot_words", type=str, help="Hot-words and their boosts.")
args = parser.parse_args() args = parser.parse_args()
print('Loading model from file {}'.format(args.model), file=sys.stderr) print("Loading model from file {}".format(args.model), file=sys.stderr)
model_load_start = timer() model_load_start = timer()
# sphinx-doc: python_ref_model_start # sphinx-doc: python_ref_model_start
ds = Model(args.model) ds = Model(args.model)
# sphinx-doc: python_ref_model_stop # sphinx-doc: python_ref_model_stop
model_load_end = timer() - model_load_start model_load_end = timer() - model_load_start
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr) print("Loaded model in {:.3}s.".format(model_load_end), file=sys.stderr)
if args.beam_width: if args.beam_width:
ds.setBeamWidth(args.beam_width) ds.setBeamWidth(args.beam_width)
@ -127,44 +156,55 @@ def main():
desired_sample_rate = ds.sampleRate() desired_sample_rate = ds.sampleRate()
if args.scorer: if args.scorer:
print('Loading scorer from files {}'.format(args.scorer), file=sys.stderr) print("Loading scorer from files {}".format(args.scorer), file=sys.stderr)
scorer_load_start = timer() scorer_load_start = timer()
ds.enableExternalScorer(args.scorer) ds.enableExternalScorer(args.scorer)
scorer_load_end = timer() - scorer_load_start scorer_load_end = timer() - scorer_load_start
print('Loaded scorer in {:.3}s.'.format(scorer_load_end), file=sys.stderr) print("Loaded scorer in {:.3}s.".format(scorer_load_end), file=sys.stderr)
if args.lm_alpha and args.lm_beta: if args.lm_alpha and args.lm_beta:
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta) ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
if args.hot_words: if args.hot_words:
print('Adding hot-words', file=sys.stderr) print("Adding hot-words", file=sys.stderr)
for word_boost in args.hot_words.split(','): for word_boost in args.hot_words.split(","):
word,boost = word_boost.split(':') word, boost = word_boost.split(":")
ds.addHotWord(word,float(boost)) ds.addHotWord(word, float(boost))
fin = wave.open(args.audio, 'rb') fin = wave.open(args.audio, "rb")
fs_orig = fin.getframerate() fs_orig = fin.getframerate()
if fs_orig != desired_sample_rate: if fs_orig != desired_sample_rate:
print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs_orig, desired_sample_rate), file=sys.stderr) print(
"Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.".format(
fs_orig, desired_sample_rate
),
file=sys.stderr,
)
fs_new, audio = convert_samplerate(args.audio, desired_sample_rate) fs_new, audio = convert_samplerate(args.audio, desired_sample_rate)
else: else:
audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
audio_length = fin.getnframes() * (1/fs_orig) audio_length = fin.getnframes() * (1 / fs_orig)
fin.close() fin.close()
print('Running inference.', file=sys.stderr) print("Running inference.", file=sys.stderr)
inference_start = timer() inference_start = timer()
# sphinx-doc: python_ref_inference_start # sphinx-doc: python_ref_inference_start
if args.extended: if args.extended:
print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0])) print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0]))
elif args.json: elif args.json:
print(metadata_json_output(ds.sttWithMetadata(audio, args.candidate_transcripts))) print(
metadata_json_output(ds.sttWithMetadata(audio, args.candidate_transcripts))
)
else: else:
print(ds.stt(audio)) print(ds.stt(audio))
# sphinx-doc: python_ref_inference_stop # sphinx-doc: python_ref_inference_stop
inference_end = timer() - inference_start inference_end = timer() - inference_start
print('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length), file=sys.stderr) print(
"Inference took %0.3fs for %0.3fs audio file." % (inference_end, audio_length),
file=sys.stderr,
)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,107 +1,122 @@
#! /usr/bin/env python #! /usr/bin/env python
from setuptools import setup, Extension
from distutils.command.build import build
import os import os
import subprocess import subprocess
import sys import sys
from distutils.command.build import build
from setuptools import Extension, setup
def main():
    """
    Configure and run setuptools for the STT Python bindings.

    Inputs are taken from the environment and the command line:

    * ``NUMPY_INCLUDE`` overrides the NumPy header directory (and must be set
      when NumPy itself cannot be imported).
    * ``NUMPY_DEP_VERSION`` is appended verbatim to the ``numpy`` requirement.
    * ``MODEL_LDFLAGS`` / ``MODEL_LIBS`` supply library directories and
      library names for the extension.
    * ``--project_name NAME`` on the command line overrides the package name
      (default ``STT``); both tokens are removed from ``sys.argv`` before
      setuptools parses it.

    The package version is read from ``../../training/coqui_stt_training/VERSION``
    relative to the current working directory.
    """
    try:
        import numpy

        try:
            numpy_include = numpy.get_include()
        except AttributeError:
            # Fallback accessor used when get_include() does not exist.
            numpy_include = numpy.get_numpy_include()
    except ImportError:
        numpy_include = ""
        # Without NumPy the include path has to come from the environment.
        assert "NUMPY_INCLUDE" in os.environ

    def read(fname):
        # Read a file that sits next to this script, closing the handle
        # deterministically instead of leaving it to the garbage collector.
        with open(os.path.join(os.path.dirname(__file__), fname)) as fin:
            return fin.read()

    numpy_include = os.getenv("NUMPY_INCLUDE", numpy_include)
    numpy_min_ver = os.getenv("NUMPY_DEP_VERSION", "")

    project_name = "STT"
    if "--project_name" in sys.argv:
        project_name_idx = sys.argv.index("--project_name")
        project_name = sys.argv[project_name_idx + 1]
        # After the flag is removed, its value has shifted into the flag's slot.
        sys.argv.remove("--project_name")
        sys.argv.pop(project_name_idx)

    with open("../../training/coqui_stt_training/VERSION", "r") as ver:
        project_version = ver.read().strip()

    class BuildExtFirst(build):
        # Same sub-commands as the stock build command, with build_ext listed
        # first. NOTE(review): presumably so the SWIG-generated Python module
        # exists before build_py runs -- confirm.
        sub_commands = [
            ("build_ext", build.has_ext_modules),
            ("build_py", build.has_pure_modules),
            ("build_clib", build.has_c_libraries),
            ("build_scripts", build.has_scripts),
        ]

    # Properly pass arguments for linking, setuptools will perform some checks
    def lib_dirs_split(a):
        if os.name == "posix":
            return a.split("-L")[1:]

        if os.name == "nt":
            return []

        raise AssertionError("os.name == java not expected")

    def libs_split(a):
        if os.name == "posix":
            return a.split("-l")[1:]

        if os.name == "nt":
            return a.split(".lib")[0:1]

        raise AssertionError("os.name == java not expected")

    library_dirs = [
        entry.strip() for entry in lib_dirs_split(os.getenv("MODEL_LDFLAGS", ""))
    ]
    libraries = [entry.strip() for entry in libs_split(os.getenv("MODEL_LIBS", ""))]

    ds_ext = Extension(
        name="stt._impl",
        sources=["impl.i"],
        include_dirs=[numpy_include, "../"],
        library_dirs=library_dirs,
        libraries=libraries,
        swig_opts=["-c++", "-keyword"],
    )

    setup(
        name=project_name,
        description="A library for doing speech recognition using a Coqui STT model",
        long_description=read("README.rst"),
        long_description_content_type="text/x-rst; charset=UTF-8",
        author="Coqui GmbH",
        version=project_version,
        package_dir={"stt": "."},
        cmdclass={"build": BuildExtFirst},
        license="MPL-2.0",
        url="https://github.com/coqui-ai/STT",
        project_urls={
            "Documentation": "https://stt.readthedocs.io",
            "Tracker": "https://github.com/coqui-ai/STT/issues",
            "Repository": "https://github.com/coqui-ai/STT/tree/v{}".format(
                project_version
            ),
            "Discussions": "https://github.com/coqui-ai/STT/discussions",
        },
        ext_modules=[ds_ext],
        py_modules=["stt", "stt.client", "stt.impl"],
        entry_points={"console_scripts": ["stt=stt.client:main"]},
        install_requires=["numpy%s" % numpy_min_ver],
        include_package_data=True,
        classifiers=[
            "Development Status :: 3 - Alpha",
            "Environment :: Console",
            "Intended Audience :: Developers",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)",
            "Programming Language :: Python :: 2.7",
            "Programming Language :: Python :: 3.4",
            "Programming Language :: Python :: 3.5",
            "Programming Language :: Python :: 3.6",
            "Topic :: Multimedia :: Sound/Audio :: Speech",
            "Topic :: Scientific/Engineering :: Human Machine Interfaces",
            "Topic :: Scientific/Engineering",
            "Topic :: Utilities",
        ],
    )
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -9,5 +9,3 @@
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
// In this header, you should import all the public headers of your framework using statements like #import <stt_ios/PublicHeader.h> // In this header, you should import all the public headers of your framework using statements like #import <stt_ios/PublicHeader.h>

View File

@ -62,4 +62,3 @@ class SceneDelegate: UIResponder, UIWindowSceneDelegate {
} }

View File

@ -3,22 +3,26 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import argparse import argparse
import numpy as np
import wave import wave
import numpy as np
from stt import Model from stt import Model
def main(): def main():
parser = argparse.ArgumentParser(description='Running STT inference.') parser = argparse.ArgumentParser(description="Running STT inference.")
parser.add_argument('--model', required=True, parser.add_argument(
help='Path to the model (protocol buffer binary file)') "--model", required=True, help="Path to the model (protocol buffer binary file)"
parser.add_argument('--scorer', nargs='?', )
help='Path to the external scorer file') parser.add_argument("--scorer", nargs="?", help="Path to the external scorer file")
parser.add_argument('--audio1', required=True, parser.add_argument(
help='First audio file to use in interleaved streams') "--audio1", required=True, help="First audio file to use in interleaved streams"
parser.add_argument('--audio2', required=True, )
help='Second audio file to use in interleaved streams') parser.add_argument(
"--audio2",
required=True,
help="Second audio file to use in interleaved streams",
)
args = parser.parse_args() args = parser.parse_args()
ds = Model(args.model) ds = Model(args.model)
@ -26,12 +30,12 @@ def main():
if args.scorer: if args.scorer:
ds.enableExternalScorer(args.scorer) ds.enableExternalScorer(args.scorer)
fin = wave.open(args.audio1, 'rb') fin = wave.open(args.audio1, "rb")
fs1 = fin.getframerate() fs1 = fin.getframerate()
audio1 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) audio1 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
fin.close() fin.close()
fin = wave.open(args.audio2, 'rb') fin = wave.open(args.audio2, "rb")
fs2 = fin.getframerate() fs2 = fin.getframerate()
audio2 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16) audio2 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
fin.close() fin.close()
@ -49,5 +53,6 @@ def main():
print(stream1.finishStream()) print(stream1.finishStream())
print(stream2.finishStream()) print(stream2.finishStream())
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -8,77 +8,74 @@ from setuptools import find_packages, setup
def main():
    """
    Run setuptools for the ``coqui_stt_training`` package.

    The version is read from the ``VERSION`` file next to this script. The
    CTC decoder and TensorFlow requirements are added unless the
    ``DS_NODECODER`` / ``DS_NOTENSORFLOW`` environment variables are set to a
    non-empty value.
    """
    with open(str(Path(__file__).parent / "VERSION")) as fin:
        version = fin.read().strip()

    base_requirements = [
        "absl-py",
        "attrdict",
        "bs4",
        "numpy",
        "optuna",
        "opuslib == 2.0.0",
        "pandas",
        "progressbar2",
        "pyogg >= 0.6.14a1",
        "pyxdg",
        "resampy >= 0.2.2",
        "requests",
        "semver",
        "six",
        "sox",
        "soundfile",
    ]

    # Start from a copy of the base list and append the optional pins.
    install_requires = list(base_requirements)

    if not os.environ.get("DS_NODECODER", ""):
        install_requires.append("coqui_stt_ctcdecoder == {}".format(version))

    if not os.environ.get("DS_NOTENSORFLOW", ""):
        install_requires.append("tensorflow == 1.15.4")

    classifiers = [
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Topic :: Multimedia :: Sound/Audio :: Speech",
        "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)",
        "Programming Language :: Python :: 3",
    ]

    # Data files shipped inside the installed package.
    package_data = {
        "coqui_stt_training": [
            "VERSION",
            "GRAPH_VERSION",
        ],
    }

    setup(
        name="coqui_stt_training",
        version=version,
        description="Training code for Coqui STT",
        url="https://github.com/coqui-ai/STT",
        author="Coqui STT authors",
        license="MPL-2.0",
        # For a list of valid classifiers, see https://pypi.org/classifiers/
        classifiers=classifiers,
        package_dir={"": "training"},
        packages=find_packages(where="training"),
        python_requires=">=3.5, <4",
        install_requires=install_requires,
        package_data=package_data,
    )
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import functools import functools
import pandas
from coqui_stt_training.util.helpers import secs_to_hours
from pathlib import Path from pathlib import Path
import pandas
from coqui_stt_training.util.helpers import secs_to_hours
def read_csvs(csv_files): def read_csvs(csv_files):
# Relative paths are relative to CSV location # Relative paths are relative to CSV location
@ -17,32 +17,59 @@ def read_csvs(csv_files):
sets = [] sets = []
for csv in csv_files: for csv in csv_files:
file = pandas.read_csv(csv, encoding='utf-8', na_filter=False) file = pandas.read_csv(csv, encoding="utf-8", na_filter=False)
file['wav_filename'] = file['wav_filename'].apply(functools.partial(absolutify, csv)) file["wav_filename"] = file["wav_filename"].apply(
functools.partial(absolutify, csv)
)
sets.append(file) sets.append(file)
# Concat all sets, drop any extra columns, re-index the final result as 0..N # Concat all sets, drop any extra columns, re-index the final result as 0..N
return pandas.concat(sets, join='inner', ignore_index=True) return pandas.concat(sets, join="inner", ignore_index=True)
def main():
    """
    Print the total size, file count and estimated duration of CSV data sets.

    The CSV files named by ``--csv-files`` are loaded through ``read_csvs``.
    Duration is estimated from the ``wav_filesize`` column using the sample
    rate, channel count and bits per sample given on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-csv",
        "--csv-files",
        required=True,
        help="Str. Filenames as a comma separated list",
    )
    # The audio-format options share the same shape: optional integer flags.
    for flag, default, description in (
        ("--sample-rate", 16000, "Audio sample rate"),
        ("--channels", 1, "Audio channels"),
        ("--bits-per-sample", 16, "Audio bits per sample"),
    ):
        parser.add_argument(
            flag, type=int, default=default, required=False, help=description
        )
    args = parser.parse_args()

    csv_paths = [Path(name).absolute() for name in args.csv_files.split(",")]
    frame = read_csvs(csv_paths)

    file_sizes = frame["wav_filesize"]
    bytes_per_sample = args.bits_per_sample // 8
    # 44 bytes are subtracted from every file size -- presumably the WAV
    # header length; TODO confirm.
    durations = (file_sizes - 44) / args.sample_rate / args.channels / bytes_per_sample

    print("Total bytes:", file_sizes.sum())
    print("Total files:", len(frame))
    print("Total time:", secs_to_hours(durations.sum()))
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,4 +1,3 @@
a a
b b
c c

View File

@ -1,42 +1,49 @@
import unittest import unittest
from argparse import Namespace from argparse import Namespace
from coqui_stt_training.util.importers import validate_label_eng, get_validate_label
from pathlib import Path from pathlib import Path
from coqui_stt_training.util.importers import get_validate_label, validate_label_eng
def from_here(path):
    """Return ``path`` resolved against the directory containing this module."""
    return Path(__file__).parent / path
class TestValidateLabelEng(unittest.TestCase):
    """Checks for the English label validator."""

    def test_numbers(self):
        """A transcript containing digits is expected to validate to None."""
        self.assertEqual(validate_label_eng("this is a 1 2 3 test"), None)
class TestGetValidateLabel(unittest.TestCase):
    """Checks for how ``get_validate_label`` picks a validator from CLI args."""

    # (input label, expected validator output) pairs for the default validator.
    _DEFAULT_CASES = (
        ("toto", "toto"),
        ("toto1234", None),
        ("toto1234[{[{[]", None),
    )

    def _assert_default_behaviour(self, validator):
        for label, expected in self._DEFAULT_CASES:
            self.assertEqual(validator(label), expected)

    def test_no_validate_label_locale(self):
        """Args without a ``validate_label_locale`` attribute at all."""
        self._assert_default_behaviour(get_validate_label(Namespace()))

    def test_validate_label_locale_default(self):
        """Args whose ``validate_label_locale`` is explicitly None."""
        self._assert_default_behaviour(
            get_validate_label(Namespace(validate_label_locale=None))
        )

    def test_get_validate_label_missing(self):
        """This locale module path is expected to yield no validator."""
        locale_file = from_here("test_data/validate_locale_ger.py")
        validator = get_validate_label(Namespace(validate_label_locale=locale_file))
        self.assertEqual(validator, None)

    def test_get_validate_label(self):
        """This locale module path is expected to yield a usable validator."""
        locale_file = from_here("test_data/validate_locale_fra.py")
        validator = get_validate_label(Namespace(validate_label_locale=locale_file))
        self.assertEqual(validator("toto"), "toto")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,13 +1,13 @@
import unittest
import os import os
import unittest
from coqui_stt_ctcdecoder import Alphabet from coqui_stt_ctcdecoder import Alphabet
class TestAlphabetParsing(unittest.TestCase):
class TestAlphabetParsing(unittest.TestCase):
def _ending_tester(self, file, expected): def _ending_tester(self, file, expected):
alphabet = Alphabet(os.path.join(os.path.dirname(__file__), 'test_data', file)) alphabet = Alphabet(os.path.join(os.path.dirname(__file__), "test_data", file))
label = '' label = ""
label_id = -1 label_id = -1
for expected_label, expected_label_id in expected: for expected_label, expected_label_id in expected:
try: try:
@ -22,13 +22,14 @@ class TestAlphabetParsing(unittest.TestCase):
self.assertEqual(label, expected_label) self.assertEqual(label, expected_label)
def test_macos_ending(self): def test_macos_ending(self):
self._ending_tester('alphabet_macos.txt', [('a', 0), ('b', 1), ('c', 2)]) self._ending_tester("alphabet_macos.txt", [("a", 0), ("b", 1), ("c", 2)])
def test_unix_ending(self): def test_unix_ending(self):
self._ending_tester('alphabet_unix.txt', [('a', 0), ('b', 1), ('c', 2)]) self._ending_tester("alphabet_unix.txt", [("a", 0), ("b", 1), ("c", 2)])
def test_windows_ending(self): def test_windows_ending(self):
self._ending_tester('alphabet_windows.txt', [('a', 0), ('b', 1), ('c', 2)]) self._ending_tester("alphabet_windows.txt", [("a", 0), ("b", 1), ("c", 2)])
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,27 +1,32 @@
import unittest import unittest
import numpy as np import numpy as np
from coqui_stt_training.util.helpers import (
ValueRange,
get_value_range,
pick_value_from_range,
tf_pick_value_from_range,
)
import tensorflow as tf import tensorflow as tf
from coqui_stt_training.util.helpers import ValueRange, get_value_range, pick_value_from_range, tf_pick_value_from_range
class TestValueRange(unittest.TestCase): class TestValueRange(unittest.TestCase):
def _ending_tester(self, value, value_type, expected): def _ending_tester(self, value, value_type, expected):
result = get_value_range(value, value_type) result = get_value_range(value, value_type)
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_int_str_scalar(self): def test_int_str_scalar(self):
self._ending_tester('1', int, ValueRange(1, 1, 0)) self._ending_tester("1", int, ValueRange(1, 1, 0))
def test_int_str_scalar_radius(self): def test_int_str_scalar_radius(self):
self._ending_tester('1~3', int, ValueRange(1, 1, 3)) self._ending_tester("1~3", int, ValueRange(1, 1, 3))
def test_int_str_range(self): def test_int_str_range(self):
self._ending_tester('1:2', int, ValueRange(1, 2, 0)) self._ending_tester("1:2", int, ValueRange(1, 2, 0))
def test_int_str_range_radius(self): def test_int_str_range_radius(self):
self._ending_tester('1:2~3', int, ValueRange(1, 2, 3)) self._ending_tester("1:2~3", int, ValueRange(1, 2, 3))
def test_int_scalar(self): def test_int_scalar(self):
self._ending_tester(1, int, ValueRange(1, 1, 0)) self._ending_tester(1, int, ValueRange(1, 1, 0))
@ -33,16 +38,16 @@ class TestValueRange(unittest.TestCase):
self._ending_tester((1, 2, 3), int, ValueRange(1, 2, 3)) self._ending_tester((1, 2, 3), int, ValueRange(1, 2, 3))
def test_float_str_scalar(self): def test_float_str_scalar(self):
self._ending_tester('1.0', float, ValueRange(1.0, 1.0, 0.0)) self._ending_tester("1.0", float, ValueRange(1.0, 1.0, 0.0))
def test_float_str_scalar_radius(self): def test_float_str_scalar_radius(self):
self._ending_tester('1.0~3.0', float, ValueRange(1.0, 1.0, 3.0)) self._ending_tester("1.0~3.0", float, ValueRange(1.0, 1.0, 3.0))
def test_float_str_range(self): def test_float_str_range(self):
self._ending_tester('1.0:2.0', float, ValueRange(1.0, 2.0, 0.0)) self._ending_tester("1.0:2.0", float, ValueRange(1.0, 2.0, 0.0))
def test_float_str_range_radius(self): def test_float_str_range_radius(self):
self._ending_tester('1.0:2.0~3.0', float, ValueRange(1.0, 2.0, 3.0)) self._ending_tester("1.0:2.0~3.0", float, ValueRange(1.0, 2.0, 3.0))
def test_float_scalar(self): def test_float_scalar(self):
self._ending_tester(1.0, float, ValueRange(1.0, 1.0, 0.0)) self._ending_tester(1.0, float, ValueRange(1.0, 1.0, 0.0))
@ -61,7 +66,7 @@ class TestPickValueFromFixedRange(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestPickValueFromFixedRange, self).__init__(*args, **kwargs) super(TestPickValueFromFixedRange, self).__init__(*args, **kwargs)
self.session = tf.Session() self.session = tf.Session()
self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock') self.clock_ph = tf.placeholder(dtype=tf.float64, name="clock")
def _ending_tester(self, value_range, clock, expected): def _ending_tester(self, value_range, clock, expected):
with tf.Session() as session: with tf.Session() as session:
@ -71,7 +76,10 @@ class TestPickValueFromFixedRange(unittest.TestCase):
return session.run(tf_pick, feed_dict={self.clock_ph: c}) return session.run(tf_pick, feed_dict={self.clock_ph: c})
is_int = isinstance(value_range.start, int) is_int = isinstance(value_range.start, int)
for pick, int_type, float_type in [(pick_value_from_range, int, float), (run_pick, np.int32, np.float32)]: for pick, int_type, float_type in [
(pick_value_from_range, int, float),
(run_pick, np.int32, np.float32),
]:
result = pick(value_range, clock) result = pick(value_range, clock)
self.assertEqual(result, expected) self.assertEqual(result, expected)
self.assertTrue(isinstance(result, int_type if is_int else float_type)) self.assertTrue(isinstance(result, int_type if is_int else float_type))
@ -99,9 +107,11 @@ class TestPickValueFromRandomizedRange(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestPickValueFromRandomizedRange, self).__init__(*args, **kwargs) super(TestPickValueFromRandomizedRange, self).__init__(*args, **kwargs)
self.session = tf.Session() self.session = tf.Session()
self.clock_ph = tf.placeholder(dtype=tf.float64, name='clock') self.clock_ph = tf.placeholder(dtype=tf.float64, name="clock")
def _ending_tester(self, value_range, clock_min, clock_max, expected_min, expected_max): def _ending_tester(
self, value_range, clock_min, clock_max, expected_min, expected_max
):
with self.session as session: with self.session as session:
tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph) tf_pick = tf_pick_value_from_range(value_range, clock=self.clock_ph)
@ -109,12 +119,26 @@ class TestPickValueFromRandomizedRange(unittest.TestCase):
return session.run(tf_pick, feed_dict={self.clock_ph: c}) return session.run(tf_pick, feed_dict={self.clock_ph: c})
is_int = isinstance(value_range.start, int) is_int = isinstance(value_range.start, int)
clock_range = np.arange(clock_min, clock_max, (clock_max - clock_min) / 100.0) clock_range = np.arange(
for pick, int_type, float_type in [(pick_value_from_range, int, float), (run_pick, np.int32, np.float32)]: clock_min, clock_max, (clock_max - clock_min) / 100.0
)
for pick, int_type, float_type in [
(pick_value_from_range, int, float),
(run_pick, np.int32, np.float32),
]:
results = [pick(value_range, c) for c in clock_range] results = [pick(value_range, c) for c in clock_range]
self.assertGreater(len(set(results)), 80) self.assertGreater(len(set(results)), 80)
self.assertTrue(all(map(lambda x: expected_min <= x <= expected_max, results))) self.assertTrue(
self.assertTrue(all(map(lambda x: isinstance(x, int_type if is_int else float_type), results))) all(map(lambda x: expected_min <= x <= expected_max, results))
)
self.assertTrue(
all(
map(
lambda x: isinstance(x, int_type if is_int else float_type),
results,
)
)
)
def test_int_0(self): def test_int_0(self):
self._ending_tester(ValueRange(10000, 30000, 10000), 0.0, 0.1, 0, 22000) self._ending_tester(ValueRange(10000, 30000, 10000), 0.0, 0.1, 0, 22000)
@ -126,14 +150,20 @@ class TestPickValueFromRandomizedRange(unittest.TestCase):
self._ending_tester(ValueRange(10000, 30000, 10000), 0.8, 1.0, 16000, 40000) self._ending_tester(ValueRange(10000, 30000, 10000), 0.8, 1.0, 16000, 40000)
def test_float_0(self): def test_float_0(self):
self._ending_tester(ValueRange(10000.0, 30000.0, 10000.0), 0.0, 0.1, 0.0, 22000.0) self._ending_tester(
ValueRange(10000.0, 30000.0, 10000.0), 0.0, 0.1, 0.0, 22000.0
)
def test_float_half(self): def test_float_half(self):
self._ending_tester(ValueRange(10000.0, 30000.0, 10000.0), 0.4, 0.6, 8000.0, 32000.0) self._ending_tester(
ValueRange(10000.0, 30000.0, 10000.0), 0.4, 0.6, 8000.0, 32000.0
)
def test_float_1(self): def test_float_1(self):
self._ending_tester(ValueRange(10000.0, 30000.0, 10000.0), 0.8, 1.0, 16000.0, 40000.0) self._ending_tester(
ValueRange(10000.0, 30000.0, 10000.0), 0.8, 1.0, 16000.0, 40000.0
)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -2,11 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
if __name__ == '__main__': if __name__ == "__main__":
try: try:
from coqui_stt_training import train as ds_train from coqui_stt_training import train as ds_train
except ImportError: except ImportError:
print('Training package is not installed. See training documentation.') print("Training package is not installed. See training documentation.")
raise raise
ds_train.run_script() ds_train.run_script()

View File

@ -4,33 +4,36 @@ from __future__ import absolute_import, division, print_function
import json import json
import sys import sys
from multiprocessing import cpu_count from multiprocessing import cpu_count
import absl.app import absl.app
import progressbar import progressbar
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from six.moves import zip
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.augmentations import NormalizeSampleRate from .util.augmentations import NormalizeSampleRate
from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph_for_evaluation from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals
from .util.evaluate_tools import calculate_and_print_report, save_samples_json from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS from .util.flags import FLAGS, create_flags
from .util.helpers import check_ctcdecoder_version from .util.helpers import check_ctcdecoder_version
from .util.logging import create_progressbar, log_error, log_progress from .util.logging import create_progressbar, log_error, log_progress
check_ctcdecoder_version() check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet): def sparse_tensor_value_to_texts(value, alphabet):
r""" r"""
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
representing its values, converting tokens to strings using ``alphabet``. representing its values, converting tokens to strings using ``alphabet``.
""" """
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet) return sparse_tuple_to_texts(
(value.indices, value.values, value.dense_shape), alphabet
)
def sparse_tuple_to_texts(sp_tuple, alphabet): def sparse_tuple_to_texts(sp_tuple, alphabet):
@ -45,36 +48,42 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
def evaluate(test_csvs, create_model): def evaluate(test_csvs, create_model):
if FLAGS.scorer_path: if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, scorer = Scorer(
FLAGS.scorer_path, Config.alphabet) FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
)
else: else:
scorer = None scorer = None
test_sets = [create_dataset([csv], test_sets = [
batch_size=FLAGS.test_batch_size, create_dataset(
train_phase=False, [csv],
augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], batch_size=FLAGS.test_batch_size,
reverse=FLAGS.reverse_test, train_phase=False,
limit=FLAGS.limit_test) for csv in test_csvs] augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)],
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), reverse=FLAGS.reverse_test,
tfv1.data.get_output_shapes(test_sets[0]), limit=FLAGS.limit_test,
output_classes=tfv1.data.get_output_classes(test_sets[0])) )
for csv in test_csvs
]
iterator = tfv1.data.Iterator.from_structure(
tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]),
)
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets] test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next() batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
# One rate per layer # One rate per layer
no_dropout = [None] * 6 no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x, logits, _ = create_model(
seq_length=batch_x_len, batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
dropout=no_dropout) )
# Transpose to batch major and apply softmax for decoder # Transpose to batch major and apply softmax for decoder
transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2])) transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2]))
loss = tfv1.nn.ctc_loss(labels=batch_y, loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len)
inputs=logits,
sequence_length=batch_x_len)
tfv1.train.get_or_create_global_step() tfv1.train.get_or_create_global_step()
@ -93,9 +102,11 @@ def evaluate(test_csvs, create_model):
predictions = [] predictions = []
ground_truths = [] ground_truths = []
bar = create_progressbar(prefix='Test epoch | ', bar = create_progressbar(
widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start() prefix="Test epoch | ",
log_progress('Test epoch...') widgets=["Steps: ", progressbar.Counter(), " | ", progressbar.Timer()],
).start()
log_progress("Test epoch...")
step_count = 0 step_count = 0
@ -105,17 +116,35 @@ def evaluate(test_csvs, create_model):
# First pass, compute losses and transposed logits for decoding # First pass, compute losses and transposed logits for decoding
while True: while True:
try: try:
batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \ (
session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y]) batch_wav_filenames,
batch_logits,
batch_loss,
batch_lengths,
batch_transcripts,
) = session.run(
[batch_wav_filename, transposed, loss, batch_x_len, batch_y]
)
except tf.errors.OutOfRangeError: except tf.errors.OutOfRangeError:
break break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width, decoded = ctc_beam_search_decoder_batch(
num_processes=num_processes, scorer=scorer, batch_logits,
cutoff_prob=FLAGS.cutoff_prob, cutoff_top_n=FLAGS.cutoff_top_n) batch_lengths,
Config.alphabet,
FLAGS.beam_width,
num_processes=num_processes,
scorer=scorer,
cutoff_prob=FLAGS.cutoff_prob,
cutoff_top_n=FLAGS.cutoff_top_n,
)
predictions.extend(d[0][1] for d in decoded) predictions.extend(d[0][1] for d in decoded)
ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)) ground_truths.extend(
wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames) sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)
)
wav_filenames.extend(
wav_filename.decode("UTF-8") for wav_filename in batch_wav_filenames
)
losses.extend(batch_loss) losses.extend(batch_loss)
step_count += 1 step_count += 1
@ -124,12 +153,14 @@ def evaluate(test_csvs, create_model):
bar.finish() bar.finish()
# Print test summary # Print test summary
test_samples = calculate_and_print_report(wav_filenames, ground_truths, predictions, losses, dataset) test_samples = calculate_and_print_report(
wav_filenames, ground_truths, predictions, losses, dataset
)
return test_samples return test_samples
samples = [] samples = []
for csv, init_op in zip(test_csvs, test_init_ops): for csv, init_op in zip(test_csvs, test_init_ops):
print('Testing model on {}'.format(csv)) print("Testing model on {}".format(csv))
samples.extend(run_test(init_op, dataset=csv)) samples.extend(run_test(init_op, dataset=csv))
return samples return samples
@ -138,12 +169,17 @@ def main(_):
initialize_globals() initialize_globals()
if not FLAGS.test_files: if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via ' log_error(
'the --test_files flag.') "You need to specify what files to use for evaluation via "
"the --test_files flag."
)
sys.exit(1) sys.exit(1)
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel from .train import ( # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model) create_model,
)
samples = evaluate(FLAGS.test_files.split(","), create_model)
if FLAGS.test_output_file: if FLAGS.test_output_file:
save_samples_json(samples, FLAGS.test_output_file) save_samples_json(samples, FLAGS.test_output_file)
@ -153,5 +189,6 @@ def run_script():
create_flags() create_flags()
absl.app.run(main) absl.app.run(main)
if __name__ == '__main__':
if __name__ == "__main__":
run_script() run_script()

File diff suppressed because it is too large Load Diff

View File

@ -2,28 +2,29 @@ import collections
import ctypes import ctypes
import io import io
import math import math
import numpy as np
import os import os
import pyogg
import tempfile import tempfile
import wave import wave
from collections import namedtuple
import numpy as np
import pyogg
from .helpers import LimitingPool from .helpers import LimitingPool
from collections import namedtuple from .io import copy_remote, is_remote_path, open_remote, remove_remote
from .io import open_remote, remove_remote, copy_remote, is_remote_path
AudioFormat = namedtuple('AudioFormat', 'rate channels width') AudioFormat = namedtuple("AudioFormat", "rate channels width")
DEFAULT_RATE = 16000 DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1 DEFAULT_CHANNELS = 1
DEFAULT_WIDTH = 2 DEFAULT_WIDTH = 2
DEFAULT_FORMAT = AudioFormat(DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH) DEFAULT_FORMAT = AudioFormat(DEFAULT_RATE, DEFAULT_CHANNELS, DEFAULT_WIDTH)
AUDIO_TYPE_NP = 'application/vnd.mozilla.np' AUDIO_TYPE_NP = "application/vnd.mozilla.np"
AUDIO_TYPE_PCM = 'application/vnd.mozilla.pcm' AUDIO_TYPE_PCM = "application/vnd.mozilla.pcm"
AUDIO_TYPE_WAV = 'audio/wav' AUDIO_TYPE_WAV = "audio/wav"
AUDIO_TYPE_OPUS = 'application/vnd.mozilla.opus' AUDIO_TYPE_OPUS = "application/vnd.mozilla.opus"
AUDIO_TYPE_OGG_OPUS = 'application/vnd.deepspeech.ogg_opus' AUDIO_TYPE_OGG_OPUS = "application/vnd.deepspeech.ogg_opus"
SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, AUDIO_TYPE_OGG_OPUS] SERIALIZABLE_AUDIO_TYPES = [AUDIO_TYPE_WAV, AUDIO_TYPE_OPUS, AUDIO_TYPE_OGG_OPUS]
@ -49,6 +50,7 @@ class Sample:
duration : float duration : float
Audio duration of the sample in seconds Audio duration of the sample in seconds
""" """
def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None): def __init__(self, audio_type, raw_data, audio_format=None, sample_id=None):
""" """
Parameters Parameters
@ -74,20 +76,26 @@ class Sample:
self.audio_format = audio_format self.audio_format = audio_format
self.sample_id = sample_id self.sample_id = sample_id
if audio_type in SERIALIZABLE_AUDIO_TYPES: if audio_type in SERIALIZABLE_AUDIO_TYPES:
self.audio = raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data) self.audio = (
raw_data if isinstance(raw_data, io.BytesIO) else io.BytesIO(raw_data)
)
self.duration = read_duration(audio_type, self.audio) self.duration = read_duration(audio_type, self.audio)
if not self.audio_format: if not self.audio_format:
self.audio_format = read_format(audio_type, self.audio) self.audio_format = read_format(audio_type, self.audio)
else: else:
self.audio = raw_data self.audio = raw_data
if self.audio_format is None: if self.audio_format is None:
raise ValueError('For audio type "{}" parameter "audio_format" is mandatory'.format(self.audio_type)) raise ValueError(
'For audio type "{}" parameter "audio_format" is mandatory'.format(
self.audio_type
)
)
if audio_type == AUDIO_TYPE_PCM: if audio_type == AUDIO_TYPE_PCM:
self.duration = get_pcm_duration(len(self.audio), self.audio_format) self.duration = get_pcm_duration(len(self.audio), self.audio_format)
elif audio_type == AUDIO_TYPE_NP: elif audio_type == AUDIO_TYPE_NP:
self.duration = get_np_duration(len(self.audio), self.audio_format) self.duration = get_np_duration(len(self.audio), self.audio_format)
else: else:
raise ValueError('Unsupported audio type: {}'.format(self.audio_type)) raise ValueError("Unsupported audio type: {}".format(self.audio_type))
def change_audio_type(self, new_audio_type, bitrate=None): def change_audio_type(self, new_audio_type, bitrate=None):
""" """
@ -102,7 +110,10 @@ class Sample:
""" """
if self.audio_type == new_audio_type: if self.audio_type == new_audio_type:
return return
if new_audio_type == AUDIO_TYPE_PCM and self.audio_type in SERIALIZABLE_AUDIO_TYPES: if (
new_audio_type == AUDIO_TYPE_PCM
and self.audio_type in SERIALIZABLE_AUDIO_TYPES
):
self.audio_format, audio = read_audio(self.audio_type, self.audio) self.audio_format, audio = read_audio(self.audio_type, self.audio)
self.audio.close() self.audio.close()
self.audio = audio self.audio = audio
@ -114,18 +125,27 @@ class Sample:
elif new_audio_type in SERIALIZABLE_AUDIO_TYPES: elif new_audio_type in SERIALIZABLE_AUDIO_TYPES:
self.change_audio_type(AUDIO_TYPE_PCM) self.change_audio_type(AUDIO_TYPE_PCM)
audio_bytes = io.BytesIO() audio_bytes = io.BytesIO()
write_audio(new_audio_type, audio_bytes, self.audio, audio_format=self.audio_format, bitrate=bitrate) write_audio(
new_audio_type,
audio_bytes,
self.audio,
audio_format=self.audio_format,
bitrate=bitrate,
)
audio_bytes.seek(0) audio_bytes.seek(0)
self.audio = audio_bytes self.audio = audio_bytes
else: else:
raise RuntimeError('Changing audio representation type from "{}" to "{}" not supported' raise RuntimeError(
.format(self.audio_type, new_audio_type)) 'Changing audio representation type from "{}" to "{}" not supported'.format(
self.audio_type, new_audio_type
)
)
self.audio_type = new_audio_type self.audio_type = new_audio_type
def _unpack_and_change_audio_type(sample_and_audio_type): def _unpack_and_change_audio_type(sample_and_audio_type):
packed_sample, audio_type, bitrate = sample_and_audio_type packed_sample, audio_type, bitrate = sample_and_audio_type
if hasattr(packed_sample, 'unpack'): if hasattr(packed_sample, "unpack"):
sample = packed_sample.unpack() sample = packed_sample.unpack()
else: else:
sample = packed_sample sample = packed_sample
@ -133,20 +153,31 @@ def _unpack_and_change_audio_type(sample_and_audio_type):
return sample return sample
def change_audio_types(packed_samples, audio_type=AUDIO_TYPE_PCM, bitrate=None, processes=None, process_ahead=None): def change_audio_types(
packed_samples,
audio_type=AUDIO_TYPE_PCM,
bitrate=None,
processes=None,
process_ahead=None,
):
with LimitingPool(processes=processes, process_ahead=process_ahead) as pool: with LimitingPool(processes=processes, process_ahead=process_ahead) as pool:
yield from pool.imap(_unpack_and_change_audio_type, map(lambda s: (s, audio_type, bitrate), packed_samples)) yield from pool.imap(
_unpack_and_change_audio_type,
map(lambda s: (s, audio_type, bitrate), packed_samples),
)
def get_loadable_audio_type_from_extension(ext): def get_loadable_audio_type_from_extension(ext):
return { return {
'.wav': AUDIO_TYPE_WAV, ".wav": AUDIO_TYPE_WAV,
'.opus': AUDIO_TYPE_OGG_OPUS, ".opus": AUDIO_TYPE_OGG_OPUS,
}.get(ext, None) }.get(ext, None)
def read_audio_format_from_wav_file(wav_file): def read_audio_format_from_wav_file(wav_file):
return AudioFormat(wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()) return AudioFormat(
wav_file.getframerate(), wav_file.getnchannels(), wav_file.getsampwidth()
)
def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT): def get_num_samples(pcm_buffer_size, audio_format=DEFAULT_FORMAT):
@ -163,13 +194,18 @@ def get_np_duration(np_len, audio_format=DEFAULT_FORMAT):
return np_len / audio_format.rate return np_len / audio_format.rate
def convert_audio(src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT): def convert_audio(
src_audio_path, dst_audio_path, file_type=None, audio_format=DEFAULT_FORMAT
):
import sox import sox
transformer = sox.Transformer() transformer = sox.Transformer()
transformer.set_output_format(file_type=file_type, transformer.set_output_format(
rate=audio_format.rate, file_type=file_type,
channels=audio_format.channels, rate=audio_format.rate,
bits=audio_format.width * 8) channels=audio_format.channels,
bits=audio_format.width * 8,
)
transformer.build(src_audio_path, dst_audio_path) transformer.build(src_audio_path, dst_audio_path)
@ -178,6 +214,7 @@ class AudioFile:
Audio data file wrapper that ensures that the file is loaded with the correct sample rate, channels, Audio data file wrapper that ensures that the file is loaded with the correct sample rate, channels,
and width, and converts the file on the fly otherwise. and width, and converts the file on the fly otherwise.
""" """
def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT): def __init__(self, audio_path, as_path=False, audio_format=DEFAULT_FORMAT):
self.audio_path = audio_path self.audio_path = audio_path
self.audio_format = audio_format self.audio_format = audio_format
@ -188,8 +225,8 @@ class AudioFile:
self.tmp_src_file_path = None self.tmp_src_file_path = None
def __enter__(self): def __enter__(self):
if self.audio_path.endswith('.wav'): if self.audio_path.endswith(".wav"):
self.open_file = open_remote(self.audio_path, 'rb') self.open_file = open_remote(self.audio_path, "rb")
self.open_wav = wave.open(self.open_file) self.open_wav = wave.open(self.open_file)
if read_audio_format_from_wav_file(self.open_wav) == self.audio_format: if read_audio_format_from_wav_file(self.open_wav) == self.audio_format:
if self.as_path: if self.as_path:
@ -202,15 +239,20 @@ class AudioFile:
# If the format isn't right, copy the file to local tmp dir and do the conversion on disk # If the format isn't right, copy the file to local tmp dir and do the conversion on disk
if is_remote_path(self.audio_path): if is_remote_path(self.audio_path):
_, self.tmp_src_file_path = tempfile.mkstemp(suffix='.wav') _, self.tmp_src_file_path = tempfile.mkstemp(suffix=".wav")
copy_remote(self.audio_path, self.tmp_src_file_path, True) copy_remote(self.audio_path, self.tmp_src_file_path, True)
self.audio_path = self.tmp_src_file_path self.audio_path = self.tmp_src_file_path
_, self.tmp_file_path = tempfile.mkstemp(suffix='.wav') _, self.tmp_file_path = tempfile.mkstemp(suffix=".wav")
convert_audio(self.audio_path, self.tmp_file_path, file_type='wav', audio_format=self.audio_format) convert_audio(
self.audio_path,
self.tmp_file_path,
file_type="wav",
audio_format=self.audio_format,
)
if self.as_path: if self.as_path:
return self.tmp_file_path return self.tmp_file_path
self.open_wav = wave.open(self.tmp_file_path, 'rb') self.open_wav = wave.open(self.tmp_file_path, "rb")
return self.open_wav return self.open_wav
def __exit__(self, *args): def __exit__(self, *args):
@ -230,33 +272,49 @@ def read_frames(wav_file, frame_duration_ms=30, yield_remainder=False):
while True: while True:
try: try:
data = wav_file.readframes(frame_size) data = wav_file.readframes(frame_size)
if not yield_remainder and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms: if (
not yield_remainder
and get_pcm_duration(len(data), audio_format) * 1000 < frame_duration_ms
):
break break
yield data yield data
except EOFError: except EOFError:
break break
def read_frames_from_file(audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False): def read_frames_from_file(
audio_path, audio_format=DEFAULT_FORMAT, frame_duration_ms=30, yield_remainder=False
):
with AudioFile(audio_path, audio_format=audio_format) as wav_file: with AudioFile(audio_path, audio_format=audio_format) as wav_file:
for frame in read_frames(wav_file, frame_duration_ms=frame_duration_ms, yield_remainder=yield_remainder): for frame in read_frames(
wav_file,
frame_duration_ms=frame_duration_ms,
yield_remainder=yield_remainder,
):
yield frame yield frame
def vad_split(audio_frames, def vad_split(
audio_format=DEFAULT_FORMAT, audio_frames,
num_padding_frames=10, audio_format=DEFAULT_FORMAT,
threshold=0.5, num_padding_frames=10,
aggressiveness=3): threshold=0.5,
aggressiveness=3,
):
from webrtcvad import Vad # pylint: disable=import-outside-toplevel from webrtcvad import Vad # pylint: disable=import-outside-toplevel
if audio_format.channels != 1: if audio_format.channels != 1:
raise ValueError('VAD-splitting requires mono samples') raise ValueError("VAD-splitting requires mono samples")
if audio_format.width != 2: if audio_format.width != 2:
raise ValueError('VAD-splitting requires 16 bit samples') raise ValueError("VAD-splitting requires 16 bit samples")
if audio_format.rate not in [8000, 16000, 32000, 48000]: if audio_format.rate not in [8000, 16000, 32000, 48000]:
raise ValueError('VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000') raise ValueError(
"VAD-splitting only supported for sample rates 8000, 16000, 32000, or 48000"
)
if aggressiveness not in [0, 1, 2, 3]: if aggressiveness not in [0, 1, 2, 3]:
raise ValueError('VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3') raise ValueError(
"VAD-splitting aggressiveness mode has to be one of 0, 1, 2, or 3"
)
ring_buffer = collections.deque(maxlen=num_padding_frames) ring_buffer = collections.deque(maxlen=num_padding_frames)
triggered = False triggered = False
vad = Vad(int(aggressiveness)) vad = Vad(int(aggressiveness))
@ -266,7 +324,9 @@ def vad_split(audio_frames,
for frame_index, frame in enumerate(audio_frames): for frame_index, frame in enumerate(audio_frames):
frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000 frame_duration_ms = get_pcm_duration(len(frame), audio_format) * 1000
if int(frame_duration_ms) not in [10, 20, 30]: if int(frame_duration_ms) not in [10, 20, 30]:
raise ValueError('VAD-splitting only supported for frame durations 10, 20, or 30 ms') raise ValueError(
"VAD-splitting only supported for frame durations 10, 20, or 30 ms"
)
is_speech = vad.is_speech(frame, audio_format.rate) is_speech = vad.is_speech(frame, audio_format.rate)
if not triggered: if not triggered:
ring_buffer.append((frame, is_speech)) ring_buffer.append((frame, is_speech))
@ -282,23 +342,23 @@ def vad_split(audio_frames,
num_unvoiced = len([f for f, speech in ring_buffer if not speech]) num_unvoiced = len([f for f, speech in ring_buffer if not speech])
if num_unvoiced > threshold * ring_buffer.maxlen: if num_unvoiced > threshold * ring_buffer.maxlen:
triggered = False triggered = False
yield b''.join(voiced_frames), \ yield b"".join(voiced_frames), frame_duration_ms * max(
frame_duration_ms * max(0, frame_index - len(voiced_frames)), \ 0, frame_index - len(voiced_frames)
frame_duration_ms * frame_index ), frame_duration_ms * frame_index
ring_buffer.clear() ring_buffer.clear()
voiced_frames = [] voiced_frames = []
if len(voiced_frames) > 0: if len(voiced_frames) > 0:
yield b''.join(voiced_frames), \ yield b"".join(voiced_frames), frame_duration_ms * (
frame_duration_ms * (frame_index - len(voiced_frames)), \ frame_index - len(voiced_frames)
frame_duration_ms * (frame_index + 1) ), frame_duration_ms * (frame_index + 1)
def pack_number(n, num_bytes): def pack_number(n, num_bytes):
return n.to_bytes(num_bytes, 'big', signed=False) return n.to_bytes(num_bytes, "big", signed=False)
def unpack_number(data): def unpack_number(data):
return int.from_bytes(data, 'big', signed=False) return int.from_bytes(data, "big", signed=False)
def get_opus_frame_size(rate): def get_opus_frame_size(rate):
@ -308,7 +368,8 @@ def get_opus_frame_size(rate):
def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT, bitrate=None): def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT, bitrate=None):
frame_size = get_opus_frame_size(audio_format.rate) frame_size = get_opus_frame_size(audio_format.rate)
import opuslib # pylint: disable=import-outside-toplevel import opuslib # pylint: disable=import-outside-toplevel
encoder = opuslib.Encoder(audio_format.rate, audio_format.channels, 'audio')
encoder = opuslib.Encoder(audio_format.rate, audio_format.channels, "audio")
if bitrate is not None: if bitrate is not None:
encoder.bitrate = bitrate encoder.bitrate = bitrate
chunk_size = frame_size * audio_format.channels * audio_format.width chunk_size = frame_size * audio_format.channels * audio_format.width
@ -317,10 +378,10 @@ def write_opus(opus_file, audio_data, audio_format=DEFAULT_FORMAT, bitrate=None)
opus_file.write(pack_number(audio_format.channels, OPUS_CHANNELS_SIZE)) opus_file.write(pack_number(audio_format.channels, OPUS_CHANNELS_SIZE))
opus_file.write(pack_number(audio_format.width, OPUS_WIDTH_SIZE)) opus_file.write(pack_number(audio_format.width, OPUS_WIDTH_SIZE))
for i in range(0, len(audio_data), chunk_size): for i in range(0, len(audio_data), chunk_size):
chunk = audio_data[i:i + chunk_size] chunk = audio_data[i : i + chunk_size]
# Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer # Preventing non-deterministic encoding results from uninitialized remainder of the encoder buffer
if len(chunk) < chunk_size: if len(chunk) < chunk_size:
chunk = chunk + b'\0' * (chunk_size - len(chunk)) chunk = chunk + b"\0" * (chunk_size - len(chunk))
encoded = encoder.encode(chunk, frame_size) encoded = encoder.encode(chunk, frame_size)
opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE)) opus_file.write(pack_number(len(encoded), OPUS_CHUNK_LEN_SIZE))
opus_file.write(encoded) opus_file.write(encoded)
@ -339,6 +400,7 @@ def read_opus(opus_file):
pcm_buffer_size, audio_format = read_opus_header(opus_file) pcm_buffer_size, audio_format = read_opus_header(opus_file)
frame_size = get_opus_frame_size(audio_format.rate) frame_size = get_opus_frame_size(audio_format.rate)
import opuslib # pylint: disable=import-outside-toplevel import opuslib # pylint: disable=import-outside-toplevel
decoder = opuslib.Decoder(audio_format.rate, audio_format.channels) decoder = opuslib.Decoder(audio_format.rate, audio_format.channels)
audio_data = bytearray() audio_data = bytearray()
while len(audio_data) < pcm_buffer_size: while len(audio_data) < pcm_buffer_size:
@ -357,34 +419,30 @@ def read_ogg_opus(ogg_file):
opusfile = pyogg.opus.op_open_memory( opusfile = pyogg.opus.op_open_memory(
ubyte_array.from_buffer(ogg_file_buffer), ubyte_array.from_buffer(ogg_file_buffer),
len(ogg_file_buffer), len(ogg_file_buffer),
ctypes.pointer(error) ctypes.pointer(error),
) )
if error.value != 0: if error.value != 0:
raise ValueError( raise ValueError(
("Ogg/Opus buffer could not be read." ("Ogg/Opus buffer could not be read." "Error code: {}").format(error.value)
"Error code: {}").format(error.value)
) )
channel_count = pyogg.opus.op_channel_count(opusfile, -1) channel_count = pyogg.opus.op_channel_count(opusfile, -1)
sample_rate = 48000 # opus files are always 48kHz sample_rate = 48000 # opus files are always 48kHz
sample_width = 2 # always 16-bit sample_width = 2 # always 16-bit
audio_format = AudioFormat(sample_rate, channel_count, sample_width) audio_format = AudioFormat(sample_rate, channel_count, sample_width)
# Allocate sufficient memory to store the entire PCM # Allocate sufficient memory to store the entire PCM
pcm_size = pyogg.opus.op_pcm_total(opusfile, -1) pcm_size = pyogg.opus.op_pcm_total(opusfile, -1)
Buf = pyogg.opus.opus_int16*(pcm_size*channel_count) Buf = pyogg.opus.opus_int16 * (pcm_size * channel_count)
buf = Buf() buf = Buf()
# Create a pointer to the newly allocated memory. It # Create a pointer to the newly allocated memory. It
# seems we can only do pointer arithmetic on void # seems we can only do pointer arithmetic on void
# pointers. See # pointers. See
# https://mattgwwalker.wordpress.com/2020/05/30/pointer-manipulation-in-python/ # https://mattgwwalker.wordpress.com/2020/05/30/pointer-manipulation-in-python/
buf_ptr = ctypes.cast( buf_ptr = ctypes.cast(ctypes.pointer(buf), ctypes.c_void_p)
ctypes.pointer(buf), assert buf_ptr.value is not None # for mypy
ctypes.c_void_p
)
assert buf_ptr.value is not None # for mypy
buf_ptr_zero = buf_ptr.value buf_ptr_zero = buf_ptr.value
#: Bytes per sample #: Bytes per sample
@ -396,38 +454,24 @@ def read_ogg_opus(ogg_file):
while True: while True:
# Calculate remaining buffer size # Calculate remaining buffer size
remaining_buffer = ( remaining_buffer = (
len(buf) # int len(buf) - (buf_ptr.value - buf_ptr_zero) // bytes_per_sample # int
- (buf_ptr.value - buf_ptr_zero) // bytes_per_sample
) )
# Convert buffer pointer to the desired type # Convert buffer pointer to the desired type
ptr = ctypes.cast( ptr = ctypes.cast(buf_ptr, ctypes.POINTER(pyogg.opus.opus_int16))
buf_ptr,
ctypes.POINTER(pyogg.opus.opus_int16)
)
# Read the next section of PCM # Read the next section of PCM
ns = pyogg.opus.op_read( ns = pyogg.opus.op_read(opusfile, ptr, remaining_buffer, pyogg.ogg.c_int_p())
opusfile,
ptr,
remaining_buffer,
pyogg.ogg.c_int_p()
)
# Check for errors # Check for errors
if ns < 0: if ns < 0:
raise ValueError( raise ValueError(
"Error while reading OggOpus buffer. "+ "Error while reading OggOpus buffer. " + "Error code: {}".format(ns)
"Error code: {}".format(ns)
) )
# Increment the pointer # Increment the pointer
buf_ptr.value += ( buf_ptr.value += ns * bytes_per_sample * channel_count
ns assert buf_ptr.value is not None # for mypy
* bytes_per_sample
* channel_count
)
assert buf_ptr.value is not None # for mypy
samples += ns samples += ns
@ -448,7 +492,7 @@ def read_ogg_opus(ogg_file):
def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT): def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
# wav_file is already a file-pointer here # wav_file is already a file-pointer here
with wave.open(wav_file, 'wb') as wav_file_writer: with wave.open(wav_file, "wb") as wav_file_writer:
wav_file_writer.setframerate(audio_format.rate) wav_file_writer.setframerate(audio_format.rate)
wav_file_writer.setnchannels(audio_format.channels) wav_file_writer.setnchannels(audio_format.channels)
wav_file_writer.setsampwidth(audio_format.width) wav_file_writer.setsampwidth(audio_format.width)
@ -457,7 +501,7 @@ def write_wav(wav_file, pcm_data, audio_format=DEFAULT_FORMAT):
def read_wav(wav_file): def read_wav(wav_file):
wav_file.seek(0) wav_file.seek(0)
with wave.open(wav_file, 'rb') as wav_file_reader: with wave.open(wav_file, "rb") as wav_file_reader:
audio_format = read_audio_format_from_wav_file(wav_file_reader) audio_format = read_audio_format_from_wav_file(wav_file_reader)
pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes()) pcm_data = wav_file_reader.readframes(wav_file_reader.getnframes())
return audio_format, pcm_data return audio_format, pcm_data
@ -470,20 +514,24 @@ def read_audio(audio_type, audio_file):
return read_opus(audio_file) return read_opus(audio_file)
if audio_type == AUDIO_TYPE_OGG_OPUS: if audio_type == AUDIO_TYPE_OGG_OPUS:
return read_ogg_opus(audio_file) return read_ogg_opus(audio_file)
raise ValueError('Unsupported audio type: {}'.format(audio_type)) raise ValueError("Unsupported audio type: {}".format(audio_type))
def write_audio(audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, bitrate=None): def write_audio(
audio_type, audio_file, pcm_data, audio_format=DEFAULT_FORMAT, bitrate=None
):
if audio_type == AUDIO_TYPE_WAV: if audio_type == AUDIO_TYPE_WAV:
return write_wav(audio_file, pcm_data, audio_format=audio_format) return write_wav(audio_file, pcm_data, audio_format=audio_format)
if audio_type == AUDIO_TYPE_OPUS: if audio_type == AUDIO_TYPE_OPUS:
return write_opus(audio_file, pcm_data, audio_format=audio_format, bitrate=bitrate) return write_opus(
raise ValueError('Unsupported audio type: {}'.format(audio_type)) audio_file, pcm_data, audio_format=audio_format, bitrate=bitrate
)
raise ValueError("Unsupported audio type: {}".format(audio_type))
def read_wav_duration(wav_file): def read_wav_duration(wav_file):
wav_file.seek(0) wav_file.seek(0)
with wave.open(wav_file, 'rb') as wav_file_reader: with wave.open(wav_file, "rb") as wav_file_reader:
return wav_file_reader.getnframes() / wav_file_reader.getframerate() return wav_file_reader.getnframes() / wav_file_reader.getframerate()
@ -499,19 +547,18 @@ def read_ogg_opus_duration(ogg_file):
opusfile = pyogg.opus.op_open_memory( opusfile = pyogg.opus.op_open_memory(
ubyte_array.from_buffer(ogg_file_buffer), ubyte_array.from_buffer(ogg_file_buffer),
len(ogg_file_buffer), len(ogg_file_buffer),
ctypes.pointer(error) ctypes.pointer(error),
) )
if error.value != 0: if error.value != 0:
raise ValueError( raise ValueError(
("Ogg/Opus buffer could not be read." ("Ogg/Opus buffer could not be read." "Error code: {}").format(error.value)
"Error code: {}").format(error.value)
) )
pcm_buffer_size = pyogg.opus.op_pcm_total(opusfile, -1) pcm_buffer_size = pyogg.opus.op_pcm_total(opusfile, -1)
channel_count = pyogg.opus.op_channel_count(opusfile, -1) channel_count = pyogg.opus.op_channel_count(opusfile, -1)
sample_rate = 48000 # opus files are always 48kHz sample_rate = 48000 # opus files are always 48kHz
sample_width = 2 # always 16-bit sample_width = 2 # always 16-bit
audio_format = AudioFormat(sample_rate, channel_count, sample_width) audio_format = AudioFormat(sample_rate, channel_count, sample_width)
pyogg.opus.op_free(opusfile) pyogg.opus.op_free(opusfile)
return get_pcm_duration(pcm_buffer_size, audio_format) return get_pcm_duration(pcm_buffer_size, audio_format)
@ -524,12 +571,12 @@ def read_duration(audio_type, audio_file):
return read_opus_duration(audio_file) return read_opus_duration(audio_file)
if audio_type == AUDIO_TYPE_OGG_OPUS: if audio_type == AUDIO_TYPE_OGG_OPUS:
return read_ogg_opus_duration(audio_file) return read_ogg_opus_duration(audio_file)
raise ValueError('Unsupported audio type: {}'.format(audio_type)) raise ValueError("Unsupported audio type: {}".format(audio_type))
def read_wav_format(wav_file): def read_wav_format(wav_file):
wav_file.seek(0) wav_file.seek(0)
with wave.open(wav_file, 'rb') as wav_file_reader: with wave.open(wav_file, "rb") as wav_file_reader:
return read_audio_format_from_wav_file(wav_file_reader) return read_audio_format_from_wav_file(wav_file_reader)
@ -545,20 +592,19 @@ def read_ogg_opus_format(ogg_file):
opusfile = pyogg.opus.op_open_memory( opusfile = pyogg.opus.op_open_memory(
ubyte_array.from_buffer(ogg_file_buffer), ubyte_array.from_buffer(ogg_file_buffer),
len(ogg_file_buffer), len(ogg_file_buffer),
ctypes.pointer(error) ctypes.pointer(error),
) )
if error.value != 0: if error.value != 0:
raise ValueError( raise ValueError(
("Ogg/Opus buffer could not be read." ("Ogg/Opus buffer could not be read." "Error code: {}").format(error.value)
"Error code: {}").format(error.value)
) )
channel_count = pyogg.opus.op_channel_count(opusfile, -1) channel_count = pyogg.opus.op_channel_count(opusfile, -1)
pyogg.opus.op_free(opusfile) pyogg.opus.op_free(opusfile)
sample_rate = 48000 # opus files are always 48kHz sample_rate = 48000 # opus files are always 48kHz
sample_width = 2 # always 16-bit sample_width = 2 # always 16-bit
return AudioFormat(sample_rate, channel_count, sample_width) return AudioFormat(sample_rate, channel_count, sample_width)
@ -569,12 +615,12 @@ def read_format(audio_type, audio_file):
return read_opus_format(audio_file) return read_opus_format(audio_file)
if audio_type == AUDIO_TYPE_OGG_OPUS: if audio_type == AUDIO_TYPE_OGG_OPUS:
return read_ogg_opus_format(audio_file) return read_ogg_opus_format(audio_file)
raise ValueError('Unsupported audio type: {}'.format(audio_type)) raise ValueError("Unsupported audio type: {}".format(audio_type))
def get_dtype(audio_format): def get_dtype(audio_format):
if audio_format.width not in [1, 2, 4]: if audio_format.width not in [1, 2, 4]:
raise ValueError('Unsupported sample width: {}'.format(audio_format.width)) raise ValueError("Unsupported sample width: {}".format(audio_format.width))
return [None, np.int8, np.int16, None, np.int32][audio_format.width] return [None, np.int8, np.int16, None, np.int32][audio_format.width]
@ -589,7 +635,7 @@ def pcm_to_np(pcm_data, audio_format=DEFAULT_FORMAT):
# Read interleaved channels # Read interleaved channels
nchannels = audio_format.channels nchannels = audio_format.channels
samples = samples.reshape((int(len(samples)/nchannels), nchannels)) samples = samples.reshape((int(len(samples) / nchannels), nchannels))
# Convert to 0.0-1.0 range # Convert to 0.0-1.0 range
samples = samples.astype(np.float32) / np.iinfo(dtype).max samples = samples.astype(np.float32) / np.iinfo(dtype).max
@ -624,4 +670,7 @@ def gain_db_to_ratio(gain_db):
def normalize_audio(sample_data, dbfs=3.0103): def normalize_audio(sample_data, dbfs=3.0103):
return np.maximum(np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), 1.0), -1.0) return np.maximum(
np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)), 1.0),
-1.0,
)

View File

@ -1,18 +1,33 @@
import os
import re
import math import math
import os
import random import random
import resampy import re
import numpy as np from multiprocessing import Process, Queue
from multiprocessing import Queue, Process import numpy as np
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS import resampy
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
from .sample_collections import samples_from_source, unpack_maybe from .audio import (
AUDIO_TYPE_NP,
AUDIO_TYPE_OPUS,
AUDIO_TYPE_PCM,
gain_db_to_ratio,
max_dbfs,
normalize_audio,
)
from .helpers import (
MEGABYTE,
LimitingPool,
float_range,
int_range,
pick_value_from_range,
tf_pick_value_from_range,
)
from .logging import log_info from .logging import log_info
from .sample_collections import samples_from_source, unpack_maybe
BUFFER_SIZE = 1 * MEGABYTE BUFFER_SIZE = 1 * MEGABYTE
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$') SPEC_PARSER = re.compile(r"^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$")
class Augmentation: class Augmentation:
@ -32,10 +47,10 @@ class SampleAugmentation(Augmentation):
class GraphAugmentation(Augmentation): class GraphAugmentation(Augmentation):
def __init__(self, p=1.0, domain='spectrogram'): def __init__(self, p=1.0, domain="spectrogram"):
super(GraphAugmentation, self).__init__(p) super(GraphAugmentation, self).__init__(p)
if domain not in ['signal', 'spectrogram', 'features']: if domain not in ["signal", "spectrogram", "features"]:
raise ValueError('Unsupported augmentation domain: {}'.format(domain)) raise ValueError("Unsupported augmentation domain: {}".format(domain))
self.domain = domain self.domain = domain
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
@ -43,19 +58,31 @@ class GraphAugmentation(Augmentation):
def apply_with_probability(self, tensor, transcript=None, clock=0.0): def apply_with_probability(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
rv = tf.random.stateless_uniform([], seed=(clock * tf.int32.min, clock * tf.int32.max))
return tf.cond(tf.less(rv, self.probability), rv = tf.random.stateless_uniform(
lambda: self.apply(tensor, transcript=transcript, clock=clock), [], seed=(clock * tf.int32.min, clock * tf.int32.max)
lambda: tensor) )
return tf.cond(
tf.less(rv, self.probability),
lambda: self.apply(tensor, transcript=transcript, clock=clock),
lambda: tensor,
)
def maybe_apply(self, domain, tensor, transcript=None, clock=0.0): def maybe_apply(self, domain, tensor, transcript=None, clock=0.0):
if domain == self.domain: if domain == self.domain:
return self.apply_with_probability(tensor, transcript=transcript, clock=clock) return self.apply_with_probability(
tensor, transcript=transcript, clock=clock
)
return tensor return tensor
def units_per_ms(self): def units_per_ms(self):
from .flags import FLAGS # pylint: disable=import-outside-toplevel from .flags import FLAGS # pylint: disable=import-outside-toplevel
return FLAGS.audio_sample_rate / 1000.0 if self.domain == 'signal' else 1.0 / FLAGS.feature_win_step
return (
FLAGS.audio_sample_rate / 1000.0
if self.domain == "signal"
else 1.0 / FLAGS.feature_win_step
)
def parse_augmentation(augmentation_spec): def parse_augmentation(augmentation_spec):
@ -73,24 +100,34 @@ def parse_augmentation(augmentation_spec):
""" """
match = SPEC_PARSER.match(augmentation_spec) match = SPEC_PARSER.match(augmentation_spec)
if not match: if not match:
raise ValueError('Augmentation specification has wrong format') raise ValueError("Augmentation specification has wrong format")
cls_name = ''.join(map(lambda p: p[0].upper() + p[1:], match.group('cls').split('_'))) cls_name = "".join(
map(lambda p: p[0].upper() + p[1:], match.group("cls").split("_"))
)
augmentation_cls = globals()[cls_name] if cls_name in globals() else None augmentation_cls = globals()[cls_name] if cls_name in globals() else None
if augmentation_cls is None or not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation: if (
raise ValueError('Unknown augmentation: {}'.format(cls_name)) augmentation_cls is None
parameters = match.group('params') or not issubclass(augmentation_cls, Augmentation)
parameters = [] if parameters is None else parameters.split(',') or augmentation_cls == Augmentation
):
raise ValueError("Unknown augmentation: {}".format(cls_name))
parameters = match.group("params")
parameters = [] if parameters is None else parameters.split(",")
args = [] args = []
kwargs = {} kwargs = {}
for parameter in parameters: for parameter in parameters:
pair = tuple(list(map(str.strip, (parameter.split('='))))) pair = tuple(list(map(str.strip, (parameter.split("=")))))
if len(pair) == 1: if len(pair) == 1:
args.append(pair) args.append(pair)
elif len(pair) == 2: elif len(pair) == 2:
kwargs[pair[0]] = pair[1] kwargs[pair[0]] = pair[1]
else: else:
raise ValueError('Unable to parse augmentation value assignment') raise ValueError("Unable to parse augmentation value assignment")
log_info('Processed augmentation type: [{}] with parameter settings: {}'.format(augmentation_cls.__name__, kwargs)) log_info(
"Processed augmentation type: [{}] with parameter settings: {}".format(
augmentation_cls.__name__, kwargs
)
)
return augmentation_cls(*args, **kwargs) return augmentation_cls(*args, **kwargs)
@ -110,7 +147,9 @@ def parse_augmentations(augmentation_specs):
return list(map(parse_augmentation, augmentation_specs or [])) return list(map(parse_augmentation, augmentation_specs or []))
def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, clock=0.0): def apply_graph_augmentations(
domain, tensor, augmentations, transcript=None, clock=0.0
):
""" """
Augments training sample tensor of a certain domain with matching augmentations of passed list. Augments training sample tensor of a certain domain with matching augmentations of passed list.
@ -134,7 +173,9 @@ def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, cl
if augmentations: if augmentations:
for augmentation in augmentations: for augmentation in augmentations:
if isinstance(augmentation, GraphAugmentation): if isinstance(augmentation, GraphAugmentation):
tensor = augmentation.maybe_apply(domain, tensor, transcript=transcript, clock=clock) tensor = augmentation.maybe_apply(
domain, tensor, transcript=transcript, clock=clock
)
return tensor return tensor
@ -168,13 +209,15 @@ def _augment_sample(timed_sample, context=None):
return sample return sample
def apply_sample_augmentations(samples, def apply_sample_augmentations(
augmentations, samples,
audio_type=AUDIO_TYPE_NP, augmentations,
buffering=BUFFER_SIZE, audio_type=AUDIO_TYPE_NP,
process_ahead=None, buffering=BUFFER_SIZE,
clock=0.0, process_ahead=None,
final_clock=None): clock=0.0,
final_clock=None,
):
""" """
Prepares samples for being used during training. Prepares samples for being used during training.
This includes parallel and buffered application of augmentations and a conversion to a specified audio-type. This includes parallel and buffered application of augmentations and a conversion to a specified audio-type.
@ -201,20 +244,27 @@ def apply_sample_augmentations(samples,
------- -------
iterable of util.sample_collections.LabeledSample or util.audio.Sample iterable of util.sample_collections.LabeledSample or util.audio.Sample
""" """
def timed_samples(): def timed_samples():
if final_clock is None: if final_clock is None:
for sample in samples: for sample in samples:
yield sample, clock yield sample, clock
else: else:
for sample_index, sample in enumerate(samples): for sample_index, sample in enumerate(samples):
sample_clock = clock + (final_clock - clock) * (sample_index / len(samples)) sample_clock = clock + (final_clock - clock) * (
sample_index / len(samples)
)
yield sample, sample_clock yield sample, sample_clock
assert 0.0 <= clock <= 1.0 assert 0.0 <= clock <= 1.0
if final_clock is not None: if final_clock is not None:
assert 0.0 <= final_clock <= 1.0 assert 0.0 <= final_clock <= 1.0
assert clock <= final_clock assert clock <= final_clock
augmentations = [aug for aug in augmentations if isinstance(aug, SampleAugmentation)] if augmentations else [] augmentations = (
[aug for aug in augmentations if isinstance(aug, SampleAugmentation)]
if augmentations
else []
)
try: try:
for augmentation in augmentations: for augmentation in augmentations:
augmentation.start(buffering=buffering) augmentation.start(buffering=buffering)
@ -223,9 +273,11 @@ def apply_sample_augmentations(samples,
for timed_sample in timed_samples(): for timed_sample in timed_samples():
yield _load_and_augment_sample(timed_sample, context=context) yield _load_and_augment_sample(timed_sample, context=context)
else: else:
with LimitingPool(process_ahead=process_ahead, with LimitingPool(
initializer=_init_augmentation_worker, process_ahead=process_ahead,
initargs=(context,)) as pool: initializer=_init_augmentation_worker,
initargs=(context,),
) as pool:
yield from pool.imap(_load_and_augment_sample, timed_samples()) yield from pool.imap(_load_and_augment_sample, timed_samples())
finally: finally:
for augmentation in augmentations: for augmentation in augmentations:
@ -247,6 +299,7 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
class Overlay(SampleAugmentation): class Overlay(SampleAugmentation):
"""See "Overlay augmentation" in training documentation""" """See "Overlay augmentation" in training documentation"""
def __init__(self, source, p=1.0, snr=3.0, layers=1): def __init__(self, source, p=1.0, snr=3.0, layers=1):
super(Overlay, self).__init__(p) super(Overlay, self).__init__(p)
self.source = source self.source = source
@ -257,10 +310,14 @@ class Overlay(SampleAugmentation):
self.enqueue_process = None self.enqueue_process = None
def start(self, buffering=BUFFER_SIZE): def start(self, buffering=BUFFER_SIZE):
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count()))) self.queue = Queue(
self.enqueue_process = Process(target=_enqueue_overlay_samples, max(1, math.floor(self.probability * self.layers[1] * os.cpu_count()))
args=(self.source, self.queue), )
kwargs={'buffering': buffering}) self.enqueue_process = Process(
target=_enqueue_overlay_samples,
args=(self.source, self.queue),
kwargs={"buffering": buffering},
)
self.enqueue_process.start() self.enqueue_process.start()
def apply(self, sample, clock=0.0): def apply(self, sample, clock=0.0):
@ -280,11 +337,15 @@ class Overlay(SampleAugmentation):
n_required = len(audio) - overlay_offset n_required = len(audio) - overlay_offset
n_current = len(self.current_sample) n_current = len(self.current_sample)
if n_required >= n_current: # take it completely if n_required >= n_current: # take it completely
overlay_data[overlay_offset:overlay_offset + n_current] += self.current_sample overlay_data[
overlay_offset : overlay_offset + n_current
] += self.current_sample
overlay_offset += n_current overlay_offset += n_current
self.current_sample = None self.current_sample = None
else: # take required slice from head and keep tail for next layer or sample else: # take required slice from head and keep tail for next layer or sample
overlay_data[overlay_offset:overlay_offset + n_required] += self.current_sample[0:n_required] overlay_data[
overlay_offset : overlay_offset + n_required
] += self.current_sample[0:n_required]
overlay_offset += n_required overlay_offset += n_required
self.current_sample = self.current_sample[n_required:] self.current_sample = self.current_sample[n_required:]
snr_db = pick_value_from_range(self.snr, clock=clock) snr_db = pick_value_from_range(self.snr, clock=clock)
@ -303,18 +364,24 @@ class Overlay(SampleAugmentation):
class Codec(SampleAugmentation): class Codec(SampleAugmentation):
"""See "Codec augmentation" in training documentation""" """See "Codec augmentation" in training documentation"""
def __init__(self, p=1.0, bitrate=3200): def __init__(self, p=1.0, bitrate=3200):
super(Codec, self).__init__(p) super(Codec, self).__init__(p)
self.bitrate = int_range(bitrate) self.bitrate = int_range(bitrate)
def apply(self, sample, clock=0.0): def apply(self, sample, clock=0.0):
bitrate = pick_value_from_range(self.bitrate, clock=clock) bitrate = pick_value_from_range(self.bitrate, clock=clock)
sample.change_audio_type(new_audio_type=AUDIO_TYPE_PCM) # decoding to ensure it has to get encoded again sample.change_audio_type(
sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream new_audio_type=AUDIO_TYPE_PCM
) # decoding to ensure it has to get encoded again
sample.change_audio_type(
new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate
) # will get decoded again downstream
class Reverb(SampleAugmentation): class Reverb(SampleAugmentation):
"""See "Reverb augmentation" in training documentation""" """See "Reverb augmentation" in training documentation"""
def __init__(self, p=1.0, delay=20.0, decay=10.0): def __init__(self, p=1.0, delay=20.0, decay=10.0):
super(Reverb, self).__init__(p) super(Reverb, self).__init__(p)
self.delay = float_range(delay) self.delay = float_range(delay)
@ -331,13 +398,17 @@ class Reverb(SampleAugmentation):
primes = [17, 19, 23, 29, 31] primes = [17, 19, 23, 29, 31]
for delay_prime in primes: # primes to minimize comb filter interference for delay_prime in primes: # primes to minimize comb filter interference
layer = np.copy(audio) layer = np.copy(audio)
n_delay = math.floor(delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0) n_delay = math.floor(
n_delay = max(16, n_delay) # 16 samples minimum to avoid performance trap and risk of division by zero delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0
)
n_delay = max(
16, n_delay
) # 16 samples minimum to avoid performance trap and risk of division by zero
for w_index in range(0, math.floor(len(audio) / n_delay)): for w_index in range(0, math.floor(len(audio) / n_delay)):
w1 = w_index * n_delay w1 = w_index * n_delay
w2 = (w_index + 1) * n_delay w2 = (w_index + 1) * n_delay
width = min(len(audio) - w2, n_delay) # last window could be smaller width = min(len(audio) - w2, n_delay) # last window could be smaller
layer[w2:w2 + width] += decay * layer[w1:w1 + width] layer[w2 : w2 + width] += decay * layer[w1 : w1 + width]
result += layer result += layer
audio = normalize_audio(result, dbfs=orig_dbfs) audio = normalize_audio(result, dbfs=orig_dbfs)
sample.audio = np.array(audio, dtype=np.float32) sample.audio = np.array(audio, dtype=np.float32)
@ -345,6 +416,7 @@ class Reverb(SampleAugmentation):
class Resample(SampleAugmentation): class Resample(SampleAugmentation):
"""See "Resample augmentation" in training documentation""" """See "Resample augmentation" in training documentation"""
def __init__(self, p=1.0, rate=8000): def __init__(self, p=1.0, rate=8000):
super(Resample, self).__init__(p) super(Resample, self).__init__(p)
self.rate = int_range(rate) self.rate = int_range(rate)
@ -353,8 +425,12 @@ class Resample(SampleAugmentation):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
rate = pick_value_from_range(self.rate, clock=clock) rate = pick_value_from_range(self.rate, clock=clock)
orig_len = len(sample.audio) orig_len = len(sample.audio)
resampled = resampy.resample(sample.audio, sample.audio_format.rate, rate, axis=0, filter='kaiser_fast') resampled = resampy.resample(
sample.audio = resampy.resample(resampled, rate, sample.audio_format.rate, axis=0, filter='kaiser_fast')[:orig_len] sample.audio, sample.audio_format.rate, rate, axis=0, filter="kaiser_fast"
)
sample.audio = resampy.resample(
resampled, rate, sample.audio_format.rate, axis=0, filter="kaiser_fast"
)[:orig_len]
class NormalizeSampleRate(SampleAugmentation): class NormalizeSampleRate(SampleAugmentation):
@ -367,12 +443,19 @@ class NormalizeSampleRate(SampleAugmentation):
return return
sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP) sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
sample.audio = resampy.resample(sample.audio, sample.audio_format.rate, self.rate, axis=0, filter='kaiser_fast') sample.audio = resampy.resample(
sample.audio,
sample.audio_format.rate,
self.rate,
axis=0,
filter="kaiser_fast",
)
sample.audio_format = sample.audio_format._replace(rate=self.rate) sample.audio_format = sample.audio_format._replace(rate=self.rate)
class Volume(SampleAugmentation): class Volume(SampleAugmentation):
"""See "Volume augmentation" in training documentation""" """See "Volume augmentation" in training documentation"""
def __init__(self, p=1.0, dbfs=3.0103): def __init__(self, p=1.0, dbfs=3.0103):
super(Volume, self).__init__(p) super(Volume, self).__init__(p)
self.target_dbfs = float_range(dbfs) self.target_dbfs = float_range(dbfs)
@ -385,55 +468,76 @@ class Volume(SampleAugmentation):
class Pitch(GraphAugmentation): class Pitch(GraphAugmentation):
"""See "Pitch augmentation" in training documentation""" """See "Pitch augmentation" in training documentation"""
def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)): def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)):
super(Pitch, self).__init__(p, domain='spectrogram') super(Pitch, self).__init__(p, domain="spectrogram")
self.pitch = float_range(pitch) self.pitch = float_range(pitch)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
original_shape = tf.shape(tensor) original_shape = tf.shape(tensor)
pitch = tf_pick_value_from_range(self.pitch, clock=clock) pitch = tf_pick_value_from_range(self.pitch, clock=clock)
new_freq_size = tf.cast(tf.cast(original_shape[2], tf.float32) * pitch, tf.int32) new_freq_size = tf.cast(
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size]) tf.cast(original_shape[2], tf.float32) * pitch, tf.int32
spectrogram_aug = tf.image.crop_to_bounding_box(spectrogram_aug, )
offset_height=0, spectrogram_aug = tf.image.resize_bilinear(
offset_width=0, tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size]
target_height=original_shape[1], )
target_width=tf.math.minimum(original_shape[2], new_freq_size)) spectrogram_aug = tf.image.crop_to_bounding_box(
spectrogram_aug = tf.cond(pitch < 1, spectrogram_aug,
lambda: tf.image.pad_to_bounding_box(spectrogram_aug, offset_height=0,
offset_height=0, offset_width=0,
offset_width=0, target_height=original_shape[1],
target_height=tf.shape(spectrogram_aug)[1], target_width=tf.math.minimum(original_shape[2], new_freq_size),
target_width=original_shape[2]), )
lambda: spectrogram_aug) spectrogram_aug = tf.cond(
pitch < 1,
lambda: tf.image.pad_to_bounding_box(
spectrogram_aug,
offset_height=0,
offset_width=0,
target_height=tf.shape(spectrogram_aug)[1],
target_width=original_shape[2],
),
lambda: spectrogram_aug,
)
return spectrogram_aug[:, :, :, 0] return spectrogram_aug[:, :, :, 0]
class Tempo(GraphAugmentation): class Tempo(GraphAugmentation):
"""See "Tempo augmentation" in training documentation""" """See "Tempo augmentation" in training documentation"""
def __init__(self, p=1.0, factor=1.1, max_time=-1): def __init__(self, p=1.0, factor=1.1, max_time=-1):
super(Tempo, self).__init__(p, domain='spectrogram') super(Tempo, self).__init__(p, domain="spectrogram")
self.factor = float_range(factor) self.factor = float_range(factor)
self.max_time = float(max_time) self.max_time = float(max_time)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
factor = tf_pick_value_from_range(self.factor, clock=clock) factor = tf_pick_value_from_range(self.factor, clock=clock)
original_shape = tf.shape(tensor) original_shape = tf.shape(tensor)
new_time_size = tf.cast(tf.cast(original_shape[1], tf.float32) / factor, tf.int32) new_time_size = tf.cast(
tf.cast(original_shape[1], tf.float32) / factor, tf.int32
)
if transcript is not None: if transcript is not None:
new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1]) new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1])
if self.max_time > 0: if self.max_time > 0:
new_time_size = tf.math.minimum(new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32)) new_time_size = tf.math.minimum(
spectrogram_aug = tf.image.resize_bilinear(tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]]) new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32)
)
spectrogram_aug = tf.image.resize_bilinear(
tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]]
)
return spectrogram_aug[:, :, :, 0] return spectrogram_aug[:, :, :, 0]
class Warp(GraphAugmentation): class Warp(GraphAugmentation):
"""See "Warp augmentation" in training documentation""" """See "Warp augmentation" in training documentation"""
def __init__(self, p=1.0, nt=1, nf=1, wt=0.1, wf=0.0): def __init__(self, p=1.0, nt=1, nf=1, wt=0.1, wf=0.0):
super(Warp, self).__init__(p, domain='spectrogram') super(Warp, self).__init__(p, domain="spectrogram")
self.num_t = int_range(nt) self.num_t = int_range(nt)
self.num_f = int_range(nf) self.num_f = int_range(nf)
self.warp_t = float_range(wt) self.warp_t = float_range(wt)
@ -441,6 +545,7 @@ class Warp(GraphAugmentation):
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
original_shape = tf.shape(tensor) original_shape = tf.shape(tensor)
size_t, size_f = original_shape[1], original_shape[2] size_t, size_f = original_shape[1], original_shape[2]
seed = (clock * tf.int32.min, clock * tf.int32.max) seed = (clock * tf.int32.min, clock * tf.int32.max)
@ -449,25 +554,43 @@ class Warp(GraphAugmentation):
def get_flows(n, size, warp): def get_flows(n, size, warp):
warp = tf_pick_value_from_range(warp, clock=clock) warp = tf_pick_value_from_range(warp, clock=clock)
warp = warp * tf.cast(size, dtype=tf.float32) / tf.cast(2 * (n + 1), dtype=tf.float32) warp = (
f = tf.random.stateless_normal([num_t, num_f], seed, mean=0.0, stddev=warp, dtype=tf.float32) warp
return tf.pad(f, tf.constant([[1, 1], [1, 1]]), 'CONSTANT') # zero flow at all edges * tf.cast(size, dtype=tf.float32)
/ tf.cast(2 * (n + 1), dtype=tf.float32)
)
f = tf.random.stateless_normal(
[num_t, num_f], seed, mean=0.0, stddev=warp, dtype=tf.float32
)
return tf.pad(
f, tf.constant([[1, 1], [1, 1]]), "CONSTANT"
) # zero flow at all edges
flows = tf.stack([get_flows(num_t, size_t, self.warp_t), get_flows(num_f, size_f, self.warp_f)], axis=2) flows = tf.stack(
[
get_flows(num_t, size_t, self.warp_t),
get_flows(num_f, size_f, self.warp_f),
],
axis=2,
)
flows = tf.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f]) flows = tf.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f])
spectrogram_aug = tf.contrib.image.dense_image_warp(tf.expand_dims(tensor, -1), flows) spectrogram_aug = tf.contrib.image.dense_image_warp(
tf.expand_dims(tensor, -1), flows
)
return tf.reshape(spectrogram_aug, shape=(1, -1, size_f)) return tf.reshape(spectrogram_aug, shape=(1, -1, size_f))
class FrequencyMask(GraphAugmentation): class FrequencyMask(GraphAugmentation):
"""See "Frequency mask augmentation" in training documentation""" """See "Frequency mask augmentation" in training documentation"""
def __init__(self, p=1.0, n=3, size=2): def __init__(self, p=1.0, n=3, size=2):
super(FrequencyMask, self).__init__(p, domain='spectrogram') super(FrequencyMask, self).__init__(p, domain="spectrogram")
self.n = int_range(n) # pylint: disable=invalid-name self.n = int_range(n) # pylint: disable=invalid-name
self.size = int_range(size) self.size = int_range(size)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
time_max = tf.shape(tensor)[1] time_max = tf.shape(tensor)[1]
freq_max = tf.shape(tensor)[2] freq_max = tf.shape(tensor)[2]
n = tf_pick_value_from_range(self.n, clock=clock) n = tf_pick_value_from_range(self.n, clock=clock)
@ -476,10 +599,21 @@ class FrequencyMask(GraphAugmentation):
size = tf_pick_value_from_range(self.size, clock=clock) size = tf_pick_value_from_range(self.size, clock=clock)
size = tf.math.maximum(1, tf.math.minimum(freq_max - 1, size)) size = tf.math.maximum(1, tf.math.minimum(freq_max - 1, size))
seed = tf.cast(clock * tf.int32.max, tf.int32) - i seed = tf.cast(clock * tf.int32.max, tf.int32) - i
f0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=freq_max - size, dtype=tf.dtypes.int32) f0 = tf.random.stateless_uniform(
freq_mask = tf.concat([tf.ones([1, time_max, f0]), (),
tf.zeros([1, time_max, size]), (-seed, seed),
tf.ones([1, time_max, freq_max - f0 - size])], axis=2) minval=0,
maxval=freq_max - size,
dtype=tf.dtypes.int32,
)
freq_mask = tf.concat(
[
tf.ones([1, time_max, f0]),
tf.zeros([1, time_max, size]),
tf.ones([1, time_max, freq_max - f0 - size]),
],
axis=2,
)
return i + 1, spectrogram_aug * freq_mask return i + 1, spectrogram_aug * freq_mask
return tf.while_loop(lambda i, spectrogram_aug: i < n, body, (0, tensor))[1] return tf.while_loop(lambda i, spectrogram_aug: i < n, body, (0, tensor))[1]
@ -487,29 +621,51 @@ class FrequencyMask(GraphAugmentation):
class TimeMask(GraphAugmentation): class TimeMask(GraphAugmentation):
"""See "Time mask augmentation" in training documentation""" """See "Time mask augmentation" in training documentation"""
def __init__(self, p=1.0, domain='spectrogram', n=3, size=10.0):
def __init__(self, p=1.0, domain="spectrogram", n=3, size=10.0):
super(TimeMask, self).__init__(p, domain=domain) super(TimeMask, self).__init__(p, domain=domain)
self.n = int_range(n) # pylint: disable=invalid-name self.n = int_range(n) # pylint: disable=invalid-name
self.size = float_range(size) self.size = float_range(size)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
time_max = tf.shape(tensor)[0 if self.domain == 'signal' else 1]
time_max = tf.shape(tensor)[0 if self.domain == "signal" else 1]
n = tf_pick_value_from_range(self.n, clock=clock) n = tf_pick_value_from_range(self.n, clock=clock)
def body(i, augmented): def body(i, augmented):
size = tf.cast(tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(), dtype=tf.int32) size = tf.cast(
tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(),
dtype=tf.int32,
)
size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size)) size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size))
seed = tf.cast(clock * tf.int32.max, tf.int32) - i seed = tf.cast(clock * tf.int32.max, tf.int32) - i
t0 = tf.random.stateless_uniform((), (-seed, seed), minval=0, maxval=time_max - size, dtype=tf.dtypes.int32) t0 = tf.random.stateless_uniform(
(),
(-seed, seed),
minval=0,
maxval=time_max - size,
dtype=tf.dtypes.int32,
)
rest = time_max - t0 - size rest = time_max - t0 - size
if self.domain == 'spectrogram': if self.domain == "spectrogram":
fm = tf.shape(tensor)[2] fm = tf.shape(tensor)[2]
time_mask = tf.concat([tf.ones([1, t0, fm]), tf.zeros([1, size, fm]), tf.ones([1, rest, fm])], axis=1) time_mask = tf.concat(
elif self.domain == 'signal': [
time_mask = tf.concat([tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0) tf.ones([1, t0, fm]),
tf.zeros([1, size, fm]),
tf.ones([1, rest, fm]),
],
axis=1,
)
elif self.domain == "signal":
time_mask = tf.concat(
[tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0
)
else: else:
time_mask = tf.concat([tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1) time_mask = tf.concat(
[tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1
)
return i + 1, augmented * time_mask return i + 1, augmented * time_mask
return tf.while_loop(lambda i, augmented: i < n, body, (0, tensor))[1] return tf.while_loop(lambda i, augmented: i < n, body, (0, tensor))[1]
@ -517,43 +673,55 @@ class TimeMask(GraphAugmentation):
class Dropout(GraphAugmentation): class Dropout(GraphAugmentation):
"""See "Dropout augmentation" in training documentation""" """See "Dropout augmentation" in training documentation"""
def __init__(self, p=1.0, domain='spectrogram', rate=0.05):
def __init__(self, p=1.0, domain="spectrogram", rate=0.05):
super(Dropout, self).__init__(p, domain=domain) super(Dropout, self).__init__(p, domain=domain)
self.rate = float_range(rate) self.rate = float_range(rate)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
rate = tf_pick_value_from_range(self.rate, clock=clock) rate = tf_pick_value_from_range(self.rate, clock=clock)
rate = tf.math.maximum(0.0, rate) rate = tf.math.maximum(0.0, rate)
factors = tf.random.stateless_uniform(tf.shape(tensor), factors = tf.random.stateless_uniform(
(clock * tf.int32.min, clock * tf.int32.max), tf.shape(tensor),
minval=0.0, (clock * tf.int32.min, clock * tf.int32.max),
maxval=1.0, minval=0.0,
dtype=tf.float32) maxval=1.0,
dtype=tf.float32,
)
return tensor * tf.math.sign(tf.math.floor(factors + rate)) return tensor * tf.math.sign(tf.math.floor(factors + rate))
class Add(GraphAugmentation): class Add(GraphAugmentation):
"""See "Add augmentation" in training documentation""" """See "Add augmentation" in training documentation"""
def __init__(self, p=1.0, domain='features', stddev=5):
def __init__(self, p=1.0, domain="features", stddev=5):
super(Add, self).__init__(p, domain=domain) super(Add, self).__init__(p, domain=domain)
self.stddev = float_range(stddev) self.stddev = float_range(stddev)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
stddev = tf_pick_value_from_range(self.stddev, clock=clock) stddev = tf_pick_value_from_range(self.stddev, clock=clock)
seed = (clock * tf.int32.min, clock * tf.int32.max) seed = (clock * tf.int32.min, clock * tf.int32.max)
return tensor + tf.random.stateless_normal(tf.shape(tensor), seed, mean=0.0, stddev=stddev) return tensor + tf.random.stateless_normal(
tf.shape(tensor), seed, mean=0.0, stddev=stddev
)
class Multiply(GraphAugmentation): class Multiply(GraphAugmentation):
"""See "Multiply augmentation" in training documentation""" """See "Multiply augmentation" in training documentation"""
def __init__(self, p=1.0, domain='features', stddev=5):
def __init__(self, p=1.0, domain="features", stddev=5):
super(Multiply, self).__init__(p, domain=domain) super(Multiply, self).__init__(p, domain=domain)
self.stddev = float_range(stddev) self.stddev = float_range(stddev)
def apply(self, tensor, transcript=None, clock=0.0): def apply(self, tensor, transcript=None, clock=0.0):
import tensorflow as tf # pylint: disable=import-outside-toplevel import tensorflow as tf # pylint: disable=import-outside-toplevel
stddev = tf_pick_value_from_range(self.stddev, clock=clock) stddev = tf_pick_value_from_range(self.stddev, clock=clock)
seed = (clock * tf.int32.min, clock * tf.int32.max) seed = (clock * tf.int32.min, clock * tf.int32.max)
return tensor * tf.random.stateless_normal(tf.shape(tensor), seed, mean=1.0, stddev=stddev) return tensor * tf.random.stateless_normal(
tf.shape(tensor), seed, mean=1.0, stddev=stddev
)

View File

@ -22,14 +22,31 @@ import csv
import os import os
import sys import sys
import unicodedata import unicodedata
from .io import open_remote from .io import open_remote
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True) parser.add_argument(
parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true") "-csv",
parser.add_argument("-unicode", "--disable-unicode-variants", help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)", action="store_true") "--csv-files",
help="Str. Filenames as a comma separated list",
required=True,
)
parser.add_argument(
"-alpha",
"--alphabet-format",
help="Bool. Print in format for alphabet.txt",
action="store_true",
)
parser.add_argument(
"-unicode",
"--disable-unicode-variants",
help="Bool. DISABLE check for unicode consistency (use with --alphabet-format)",
action="store_true",
)
args = parser.parse_args() args = parser.parse_args()
in_files = args.csv_files.split(",") in_files = args.csv_files.split(",")
@ -46,11 +63,21 @@ def main():
if not args.disable_unicode_variants: if not args.disable_unicode_variants:
unicode_transcript = unicodedata.normalize("NFKC", row[2]) unicode_transcript = unicodedata.normalize("NFKC", row[2])
if row[2] != unicode_transcript: if row[2] != unicode_transcript:
print("Your input file", in_file, "contains at least one transript with unicode chars on more than one code-point: '{}'. Consider using NFKC normalization: unicodedata.normalize('NFKC', str).".format(row[2])) print(
"Your input file",
in_file,
"contains at least one transript with unicode chars on more than one code-point: '{}'. Consider using NFKC normalization: unicodedata.normalize('NFKC', str).".format(
row[2]
),
)
sys.exit(-1) sys.exit(-1)
all_text |= set(row[2]) all_text |= set(row[2])
except IndexError: except IndexError:
print("Your input file", in_file, "is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript") print(
"Your input file",
in_file,
"is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript",
)
sys.exit(-1) sys.exit(-1)
finally: finally:
csv_file.close() csv_file.close()
@ -63,5 +90,6 @@ def main():
else: else:
print(list(all_text)) print(list(all_text))
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,9 +1,10 @@
import sys import sys
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from .flags import FLAGS from .flags import FLAGS
from .logging import log_info, log_error, log_warn from .logging import log_error, log_info, log_warn
def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True): def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@ -17,9 +18,11 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
# We explicitly allow the learning rate variable to be missing for backwards # We explicitly allow the learning rate variable to be missing for backwards
# compatibility with older checkpoints. # compatibility with older checkpoints.
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate') lr_var = set(v for v in load_vars if v.op.name == "learning_rate")
if lr_var and ('learning_rate' not in vars_in_ckpt or if lr_var and (
(FLAGS.force_initialize_learning_rate and allow_lr_init)): "learning_rate" not in vars_in_ckpt
or (FLAGS.force_initialize_learning_rate and allow_lr_init)
):
assert len(lr_var) <= 1 assert len(lr_var) <= 1
load_vars -= lr_var load_vars -= lr_var
init_vars |= lr_var init_vars |= lr_var
@ -31,7 +34,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
missing_vars = set() missing_vars = set()
for v in load_vars: for v in load_vars:
if v.op.name not in vars_in_ckpt: if v.op.name not in vars_in_ckpt:
log_warn('CUDNN variable not found: %s' % (v.op.name)) log_warn("CUDNN variable not found: %s" % (v.op.name))
missing_vars.add(v) missing_vars.add(v)
init_vars.add(v) init_vars.add(v)
@ -40,10 +43,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
# Check that the only missing variables (i.e. those to be initialised) # Check that the only missing variables (i.e. those to be initialised)
# are the Adam moment tensors, if they aren't then we have an issue # are the Adam moment tensors, if they aren't then we have an issue
missing_var_names = [v.op.name for v in missing_vars] missing_var_names = [v.op.name for v in missing_vars]
if any('Adam' not in v for v in missing_var_names): if any("Adam" not in v for v in missing_var_names):
log_error('Tried to load a CuDNN RNN checkpoint but there were ' log_error(
'more missing variables than just the Adam moment ' "Tried to load a CuDNN RNN checkpoint but there were "
'tensors. Missing variables: {}'.format(missing_var_names)) "more missing variables than just the Adam moment "
"tensors. Missing variables: {}".format(missing_var_names)
)
sys.exit(1) sys.exit(1)
if allow_drop_layers and FLAGS.drop_source_layers > 0: if allow_drop_layers and FLAGS.drop_source_layers > 0:
@ -54,12 +59,16 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
# If we want to use all layers from the source model except # If we want to use all layers from the source model except
# the last one, we use this: drop_source_layers=1 # the last one, we use this: drop_source_layers=1
if FLAGS.drop_source_layers >= 6: if FLAGS.drop_source_layers >= 6:
log_warn('The checkpoint only has 6 layers, but you are trying to drop ' log_warn(
'all of them or more than all of them. Continuing and ' "The checkpoint only has 6 layers, but you are trying to drop "
'dropping only 5 layers.') "all of them or more than all of them. Continuing and "
"dropping only 5 layers."
)
FLAGS.drop_source_layers = 5 FLAGS.drop_source_layers = 5
dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):] dropped_layers = ["2", "3", "lstm", "5", "6"][
-1 * int(FLAGS.drop_source_layers) :
]
# Initialize all variables needed for DS, but not loaded from ckpt # Initialize all variables needed for DS, but not loaded from ckpt
for v in load_vars: for v in load_vars:
if any(layer in v.op.name for layer in dropped_layers): if any(layer in v.op.name for layer in dropped_layers):
@ -67,16 +76,18 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
load_vars -= init_vars load_vars -= init_vars
for v in sorted(load_vars, key=lambda v: v.op.name): for v in sorted(load_vars, key=lambda v: v.op.name):
log_info('Loading variable from checkpoint: %s' % (v.op.name)) log_info("Loading variable from checkpoint: %s" % (v.op.name))
v.load(ckpt.get_tensor(v.op.name), session=session) v.load(ckpt.get_tensor(v.op.name), session=session)
for v in sorted(init_vars, key=lambda v: v.op.name): for v in sorted(init_vars, key=lambda v: v.op.name):
log_info('Initializing variable: %s' % (v.op.name)) log_info("Initializing variable: %s" % (v.op.name))
session.run(v.initializer) session.run(v.initializer)
def _checkpoint_path_or_none(checkpoint_filename): def _checkpoint_path_or_none(checkpoint_filename):
checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename) checkpoint = tfv1.train.get_checkpoint_state(
FLAGS.load_checkpoint_dir, checkpoint_filename
)
if not checkpoint: if not checkpoint:
return None return None
return checkpoint.model_checkpoint_path return checkpoint.model_checkpoint_path
@ -91,61 +102,65 @@ def _initialize_all_variables(session):
def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True): def _load_or_init_impl(session, method_order, allow_drop_layers, allow_lr_init=True):
for method in method_order: for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint' # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == 'best': if method == "best":
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint') ckpt_path = _checkpoint_path_or_none("best_dev_checkpoint")
if ckpt_path: if ckpt_path:
log_info('Loading best validating checkpoint from {}'.format(ckpt_path)) log_info("Loading best validating checkpoint from {}".format(ckpt_path))
return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init) return _load_checkpoint(
log_info('Could not find best validating checkpoint.') session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
)
log_info("Could not find best validating checkpoint.")
# Load most recent checkpoint, saved in checkpoint file 'checkpoint' # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == 'last': elif method == "last":
ckpt_path = _checkpoint_path_or_none('checkpoint') ckpt_path = _checkpoint_path_or_none("checkpoint")
if ckpt_path: if ckpt_path:
log_info('Loading most recent checkpoint from {}'.format(ckpt_path)) log_info("Loading most recent checkpoint from {}".format(ckpt_path))
return _load_checkpoint(session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init) return _load_checkpoint(
log_info('Could not find most recent checkpoint.') session, ckpt_path, allow_drop_layers, allow_lr_init=allow_lr_init
)
log_info("Could not find most recent checkpoint.")
# Initialize all variables # Initialize all variables
elif method == 'init': elif method == "init":
log_info('Initializing all variables.') log_info("Initializing all variables.")
return _initialize_all_variables(session) return _initialize_all_variables(session)
else: else:
log_error('Unknown initialization method: {}'.format(method)) log_error("Unknown initialization method: {}".format(method))
sys.exit(1) sys.exit(1)
log_error('All initialization methods failed ({}).'.format(method_order)) log_error("All initialization methods failed ({}).".format(method_order))
sys.exit(1) sys.exit(1)
def reload_best_checkpoint(session): def reload_best_checkpoint(session):
_load_or_init_impl(session, ['best'], allow_drop_layers=False, allow_lr_init=False) _load_or_init_impl(session, ["best"], allow_drop_layers=False, allow_lr_init=False)
def load_or_init_graph_for_training(session): def load_or_init_graph_for_training(session):
''' """
Load variables from checkpoint or initialize variables. By default this will Load variables from checkpoint or initialize variables. By default this will
try to load the best validating checkpoint, then try the last checkpoint, try to load the best validating checkpoint, then try the last checkpoint,
and finally initialize the weights from scratch. This can be overriden with and finally initialize the weights from scratch. This can be overriden with
the `--load_train` flag. See its documentation for more info. the `--load_train` flag. See its documentation for more info.
''' """
if FLAGS.load_train == 'auto': if FLAGS.load_train == "auto":
methods = ['best', 'last', 'init'] methods = ["best", "last", "init"]
else: else:
methods = [FLAGS.load_train] methods = [FLAGS.load_train]
_load_or_init_impl(session, methods, allow_drop_layers=True) _load_or_init_impl(session, methods, allow_drop_layers=True)
def load_graph_for_evaluation(session): def load_graph_for_evaluation(session):
''' """
Load variables from checkpoint. Initialization is not allowed. By default Load variables from checkpoint. Initialization is not allowed. By default
this will try to load the best validating checkpoint, then try the last this will try to load the best validating checkpoint, then try the last
checkpoint. This can be overriden with the `--load_evaluate` flag. See its checkpoint. This can be overriden with the `--load_evaluate` flag. See its
documentation for more info. documentation for more info.
''' """
if FLAGS.load_evaluate == 'auto': if FLAGS.load_evaluate == "auto":
methods = ['best', 'last'] methods = ["best", "last"]
else: else:
methods = [FLAGS.load_evaluate] methods = [FLAGS.load_evaluate]
_load_or_init_impl(session, methods, allow_drop_layers=False) _load_or_init_impl(session, methods, allow_drop_layers=False)

View File

@ -2,18 +2,20 @@ from __future__ import absolute_import, division, print_function
import os import os
import sys import sys
import tensorflow.compat.v1 as tfv1
from attrdict import AttrDict from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
from xdg import BaseDirectory as xdg
import tensorflow.compat.v1 as tfv1
from .augmentations import NormalizeSampleRate, parse_augmentations
from .flags import FLAGS from .flags import FLAGS
from .gpu import get_available_gpus from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .helpers import parse_file_size from .helpers import parse_file_size
from .augmentations import parse_augmentations, NormalizeSampleRate
from .io import path_exists_remote from .io import path_exists_remote
from .logging import log_error, log_warn
class ConfigSingleton: class ConfigSingleton:
_config = None _config = None
@ -22,11 +24,14 @@ class ConfigSingleton:
if not ConfigSingleton._config: if not ConfigSingleton._config:
raise RuntimeError("Global configuration not yet initialized.") raise RuntimeError("Global configuration not yet initialized.")
if not hasattr(ConfigSingleton._config, name): if not hasattr(ConfigSingleton._config, name):
raise RuntimeError("Configuration option {} not found in config.".format(name)) raise RuntimeError(
"Configuration option {} not found in config.".format(name)
)
return ConfigSingleton._config[name] return ConfigSingleton._config[name]
Config = ConfigSingleton() # pylint: disable=invalid-name Config = ConfigSingleton() # pylint: disable=invalid-name
def initialize_globals(): def initialize_globals():
c = AttrDict() c = AttrDict()
@ -34,16 +39,22 @@ def initialize_globals():
# Augmentations # Augmentations
c.augmentations = parse_augmentations(FLAGS.augment) c.augmentations = parse_augmentations(FLAGS.augment)
if c.augmentations and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0: if c.augmentations and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0:
log_warn('Due to current feature-cache settings the exact same sample augmentations of the first ' log_warn(
'epoch will be repeated on all following epochs. This could lead to unintended over-fitting. ' "Due to current feature-cache settings the exact same sample augmentations of the first "
'You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs.') "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
"You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
)
if FLAGS.normalize_sample_rate: if FLAGS.normalize_sample_rate:
c.augmentations = [NormalizeSampleRate(FLAGS.audio_sample_rate)] + c['augmentations'] c.augmentations = [NormalizeSampleRate(FLAGS.audio_sample_rate)] + c[
"augmentations"
]
# Caching # Caching
if FLAGS.cache_for_epochs == 1: if FLAGS.cache_for_epochs == 1:
log_warn('--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it.') log_warn(
"--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
)
# Read-buffer # Read-buffer
FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer) FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)
@ -58,26 +69,29 @@ def initialize_globals():
# Set default checkpoint dir # Set default checkpoint dir
if not FLAGS.checkpoint_dir: if not FLAGS.checkpoint_dir:
FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('stt', 'checkpoints')) FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
if FLAGS.load_train not in ['last', 'best', 'init', 'auto']: if FLAGS.load_train not in ["last", "best", "init", "auto"]:
FLAGS.load_train = 'auto' FLAGS.load_train = "auto"
if FLAGS.load_evaluate not in ['last', 'best', 'auto']: if FLAGS.load_evaluate not in ["last", "best", "auto"]:
FLAGS.load_evaluate = 'auto' FLAGS.load_evaluate = "auto"
# Set default summary dir # Set default summary dir
if not FLAGS.summary_dir: if not FLAGS.summary_dir:
FLAGS.summary_dir = xdg.save_data_path(os.path.join('stt', 'summaries')) FLAGS.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
# Standard session configuration that'll be used for all new sessions. # Standard session configuration that'll be used for all new sessions.
c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement, c.session_config = tfv1.ConfigProto(
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, allow_soft_placement=True,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads, log_device_placement=FLAGS.log_placement,
gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth)) inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth),
)
# CPU device # CPU device
c.cpu_device = '/cpu:0' c.cpu_device = "/cpu:0"
# Available GPU devices # Available GPU devices
c.available_devices = get_available_gpus(c.session_config) c.available_devices = get_available_gpus(c.session_config)
@ -98,10 +112,10 @@ def initialize_globals():
# doc/Geometry.md # doc/Geometry.md
# Number of MFCC features # Number of MFCC features
c.n_input = 26 # TODO: Determine this programmatically from the sample rate c.n_input = 26 # TODO: Determine this programmatically from the sample rate
# The number of frames in the context # The number of frames in the context
c.n_context = 9 # TODO: Determine the optimal value using a validation data set c.n_context = 9 # TODO: Determine the optimal value using a validation data set
# Number of units in hidden layers # Number of units in hidden layers
c.n_hidden = FLAGS.n_hidden c.n_hidden = FLAGS.n_hidden
@ -119,40 +133,54 @@ def initialize_globals():
c.n_hidden_3 = c.n_cell_dim c.n_hidden_3 = c.n_cell_dim
# Units in the sixth layer = number of characters in the target language plus one # Units in the sixth layer = number of characters in the target language plus one
c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label
# Size of audio window in samples # Size of audio window in samples
if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0: if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
log_error('--feature_win_len value ({}) in milliseconds ({}) multiplied ' log_error(
'by --audio_sample_rate value ({}) must be an integer value. Adjust ' "--feature_win_len value ({}) in milliseconds ({}) multiplied "
'your --feature_win_len value or resample your audio accordingly.' "by --audio_sample_rate value ({}) must be an integer value. Adjust "
''.format(FLAGS.feature_win_len, FLAGS.feature_win_len / 1000, FLAGS.audio_sample_rate)) "your --feature_win_len value or resample your audio accordingly."
"".format(
FLAGS.feature_win_len,
FLAGS.feature_win_len / 1000,
FLAGS.audio_sample_rate,
)
)
sys.exit(1) sys.exit(1)
c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000) c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)
# Stride for feature computations in samples # Stride for feature computations in samples
if (FLAGS.feature_win_step * FLAGS.audio_sample_rate) % 1000 != 0: if (FLAGS.feature_win_step * FLAGS.audio_sample_rate) % 1000 != 0:
log_error('--feature_win_step value ({}) in milliseconds ({}) multiplied ' log_error(
'by --audio_sample_rate value ({}) must be an integer value. Adjust ' "--feature_win_step value ({}) in milliseconds ({}) multiplied "
'your --feature_win_step value or resample your audio accordingly.' "by --audio_sample_rate value ({}) must be an integer value. Adjust "
''.format(FLAGS.feature_win_step, FLAGS.feature_win_step / 1000, FLAGS.audio_sample_rate)) "your --feature_win_step value or resample your audio accordingly."
"".format(
FLAGS.feature_win_step,
FLAGS.feature_win_step / 1000,
FLAGS.audio_sample_rate,
)
)
sys.exit(1) sys.exit(1)
c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000) c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
if FLAGS.one_shot_infer: if FLAGS.one_shot_infer:
if not path_exists_remote(FLAGS.one_shot_infer): if not path_exists_remote(FLAGS.one_shot_infer):
log_error('Path specified in --one_shot_infer is not a valid file.') log_error("Path specified in --one_shot_infer is not a valid file.")
sys.exit(1) sys.exit(1)
if FLAGS.train_cudnn and FLAGS.load_cudnn: if FLAGS.train_cudnn and FLAGS.load_cudnn:
log_error('Trying to use --train_cudnn, but --load_cudnn ' log_error(
'was also specified. The --load_cudnn flag is only ' "Trying to use --train_cudnn, but --load_cudnn "
'needed when converting a CuDNN RNN checkpoint to ' "was also specified. The --load_cudnn flag is only "
'a CPU-capable graph. If your system is capable of ' "needed when converting a CuDNN RNN checkpoint to "
'using CuDNN RNN, you can just specify the CuDNN RNN ' "a CPU-capable graph. If your system is capable of "
'checkpoint normally with --save_checkpoint_dir.') "using CuDNN RNN, you can just specify the CuDNN RNN "
"checkpoint normally with --save_checkpoint_dir."
)
sys.exit(1) sys.exit(1)
# If separate save and load flags were not specified, default to load and save # If separate save and load flags were not specified, default to load and save
@ -163,4 +191,4 @@ def initialize_globals():
if not FLAGS.load_checkpoint_dir: if not FLAGS.load_checkpoint_dir:
FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir
ConfigSingleton._config = c # pylint: disable=protected-access ConfigSingleton._config = c # pylint: disable=protected-access

View File

@ -1,10 +1,18 @@
import requests from os import makedirs, path
import progressbar import progressbar
import requests
from os import path, makedirs from .io import is_remote_path, open_remote, path_exists_remote
from .io import open_remote, path_exists_remote, is_remote_path
SIMPLE_BAR = [
"Progress ",
progressbar.Bar(),
" ",
progressbar.Percentage(),
" completed",
]
SIMPLE_BAR = ['Progress ', progressbar.Bar(), ' ', progressbar.Percentage(), ' completed']
def maybe_download(archive_name, target_dir, archive_url): def maybe_download(archive_name, target_dir, archive_url):
# If archive file does not exist, download it... # If archive file does not exist, download it...
@ -17,12 +25,15 @@ def maybe_download(archive_name, target_dir, archive_url):
if not path_exists_remote(archive_path): if not path_exists_remote(archive_path):
print('No archive "%s" - downloading...' % archive_path) print('No archive "%s" - downloading...' % archive_path)
req = requests.get(archive_url, stream=True) req = requests.get(archive_url, stream=True)
total_size = int(req.headers.get('content-length', 0)) total_size = int(req.headers.get("content-length", 0))
done = 0 done = 0
with open_remote(archive_path, 'wb') as f: with open_remote(archive_path, "wb") as f:
bar = progressbar.ProgressBar(max_value=total_size if total_size > 0 else progressbar.UnknownLength, widgets=SIMPLE_BAR) bar = progressbar.ProgressBar(
max_value=total_size if total_size > 0 else progressbar.UnknownLength,
widgets=SIMPLE_BAR,
)
for data in req.iter_content(1024*1024): for data in req.iter_content(1024 * 1024):
done += len(data) done += len(data)
f.write(data) f.write(data)
bar.update(done) bar.update(done)

View File

@ -9,8 +9,9 @@ import numpy as np
from attrdict import AttrDict from attrdict import AttrDict
from .flags import FLAGS from .flags import FLAGS
from .text import levenshtein
from .io import open_remote from .io import open_remote
from .text import levenshtein
def pmap(fun, iterable): def pmap(fun, iterable):
pool = Pool() pool = Pool()
@ -42,26 +43,28 @@ def process_decode_result(item):
char_length = len(ground_truth) char_length = len(ground_truth)
word_distance = levenshtein(ground_truth.split(), prediction.split()) word_distance = levenshtein(ground_truth.split(), prediction.split())
word_length = len(ground_truth.split()) word_length = len(ground_truth.split())
return AttrDict({ return AttrDict(
'wav_filename': wav_filename, {
'src': ground_truth, "wav_filename": wav_filename,
'res': prediction, "src": ground_truth,
'loss': loss, "res": prediction,
'char_distance': char_distance, "loss": loss,
'char_length': char_length, "char_distance": char_distance,
'word_distance': word_distance, "char_length": char_length,
'word_length': word_length, "word_distance": word_distance,
'cer': char_distance / char_length, "word_length": word_length,
'wer': word_distance / word_length, "cer": char_distance / char_length,
}) "wer": word_distance / word_length,
}
)
def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name): def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset_name):
r''' r"""
This routine will calculate and print a WER report. This routine will calculate and print a WER report.
It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER). loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
''' """
samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses)) samples = pmap(process_decode_result, zip(wav_filenames, labels, decodings, losses))
# Getting the WER and CER from the accumulated edit distances and lengths # Getting the WER and CER from the accumulated edit distances and lengths
@ -88,41 +91,43 @@ def print_report(samples, losses, wer, cer, dataset_name):
# Print summary # Print summary
mean_loss = np.mean(losses) mean_loss = np.mean(losses)
print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset_name, wer, cer, mean_loss)) print(
print('-' * 80) "Test on %s - WER: %f, CER: %f, loss: %f" % (dataset_name, wer, cer, mean_loss)
)
print("-" * 80)
best_samples = samples[:FLAGS.report_count] best_samples = samples[: FLAGS.report_count]
worst_samples = samples[-FLAGS.report_count:] worst_samples = samples[-FLAGS.report_count :]
median_index = int(len(samples) / 2) median_index = int(len(samples) / 2)
median_left = int(FLAGS.report_count / 2) median_left = int(FLAGS.report_count / 2)
median_right = FLAGS.report_count - median_left median_right = FLAGS.report_count - median_left
median_samples = samples[median_index - median_left:median_index + median_right] median_samples = samples[median_index - median_left : median_index + median_right]
def print_single_sample(sample): def print_single_sample(sample):
print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss)) print("WER: %f, CER: %f, loss: %f" % (sample.wer, sample.cer, sample.loss))
print(' - wav: file://%s' % sample.wav_filename) print(" - wav: file://%s" % sample.wav_filename)
print(' - src: "%s"' % sample.src) print(' - src: "%s"' % sample.src)
print(' - res: "%s"' % sample.res) print(' - res: "%s"' % sample.res)
print('-' * 80) print("-" * 80)
print('Best WER:', '\n' + '-' * 80) print("Best WER:", "\n" + "-" * 80)
for s in best_samples: for s in best_samples:
print_single_sample(s) print_single_sample(s)
print('Median WER:', '\n' + '-' * 80) print("Median WER:", "\n" + "-" * 80)
for s in median_samples: for s in median_samples:
print_single_sample(s) print_single_sample(s)
print('Worst WER:', '\n' + '-' * 80) print("Worst WER:", "\n" + "-" * 80)
for s in worst_samples: for s in worst_samples:
print_single_sample(s) print_single_sample(s)
def save_samples_json(samples, output_path): def save_samples_json(samples, output_path):
''' Save decoded tuples as JSON, converting NumPy floats to Python floats. """Save decoded tuples as JSON, converting NumPy floats to Python floats.
We set ensure_ascii=False to prevent json from escaping non-ASCII chars We set ensure_ascii=False to prevent json from escaping non-ASCII chars
in the texts. in the texts.
''' """
with open_remote(output_path, 'w') as fout: with open_remote(output_path, "w") as fout:
json.dump(samples, fout, default=float, ensure_ascii=False, indent=2) json.dump(samples, fout, default=float, ensure_ascii=False, indent=2)

View File

@ -5,117 +5,175 @@ from collections import Counter
from functools import partial from functools import partial
import numpy as np import numpy as np
import tensorflow as tf
import tensorflow as tf
from tensorflow.python.ops import gen_audio_ops as contrib_audio from tensorflow.python.ops import gen_audio_ops as contrib_audio
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
from .config import Config from .config import Config
from .text import text_to_char_array
from .flags import FLAGS from .flags import FLAGS
from .augmentations import apply_sample_augmentations, apply_graph_augmentations from .helpers import MEGABYTE, remember_exception
from .audio import read_frames_from_file, vad_split, pcm_to_np, DEFAULT_FORMAT
from .sample_collections import samples_from_sources from .sample_collections import samples_from_sources
from .helpers import remember_exception, MEGABYTE from .text import text_to_char_array
def audio_to_features(audio, sample_rate, transcript=None, clock=0.0, train_phase=False, augmentations=None, sample_id=None): def audio_to_features(
audio,
sample_rate,
transcript=None,
clock=0.0,
train_phase=False,
augmentations=None,
sample_id=None,
):
if train_phase: if train_phase:
# We need the lambdas to make TensorFlow happy. # We need the lambdas to make TensorFlow happy.
# pylint: disable=unnecessary-lambda # pylint: disable=unnecessary-lambda
tf.cond(tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate), tf.cond(
lambda: tf.print('WARNING: sample rate of sample', sample_id, '(', sample_rate, ') ' tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate),
'does not match FLAGS.audio_sample_rate. This can lead to incorrect results.'), lambda: tf.print(
lambda: tf.no_op(), "WARNING: sample rate of sample",
name='matching_sample_rate') sample_id,
"(",
sample_rate,
") "
"does not match FLAGS.audio_sample_rate. This can lead to incorrect results.",
),
lambda: tf.no_op(),
name="matching_sample_rate",
)
if train_phase and augmentations: if train_phase and augmentations:
audio = apply_graph_augmentations('signal', audio, augmentations, transcript=transcript, clock=clock) audio = apply_graph_augmentations(
"signal", audio, augmentations, transcript=transcript, clock=clock
)
spectrogram = contrib_audio.audio_spectrogram(audio, spectrogram = contrib_audio.audio_spectrogram(
window_size=Config.audio_window_samples, audio,
stride=Config.audio_step_samples, window_size=Config.audio_window_samples,
magnitude_squared=True) stride=Config.audio_step_samples,
magnitude_squared=True,
)
if train_phase and augmentations: if train_phase and augmentations:
spectrogram = apply_graph_augmentations('spectrogram', spectrogram, augmentations, transcript=transcript, clock=clock) spectrogram = apply_graph_augmentations(
"spectrogram",
spectrogram,
augmentations,
transcript=transcript,
clock=clock,
)
features = contrib_audio.mfcc(spectrogram=spectrogram, features = contrib_audio.mfcc(
sample_rate=sample_rate, spectrogram=spectrogram,
dct_coefficient_count=Config.n_input, sample_rate=sample_rate,
upper_frequency_limit=FLAGS.audio_sample_rate / 2) dct_coefficient_count=Config.n_input,
upper_frequency_limit=FLAGS.audio_sample_rate / 2,
)
features = tf.reshape(features, [-1, Config.n_input]) features = tf.reshape(features, [-1, Config.n_input])
if train_phase and augmentations: if train_phase and augmentations:
features = apply_graph_augmentations('features', features, augmentations, transcript=transcript, clock=clock) features = apply_graph_augmentations(
"features", features, augmentations, transcript=transcript, clock=clock
)
return features, tf.shape(input=features)[0] return features, tf.shape(input=features)[0]
def audiofile_to_features(wav_filename, clock=0.0, train_phase=False, augmentations=None): def audiofile_to_features(
wav_filename, clock=0.0, train_phase=False, augmentations=None
):
samples = tf.io.read_file(wav_filename) samples = tf.io.read_file(wav_filename)
decoded = contrib_audio.decode_wav(samples, desired_channels=1) decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return audio_to_features(decoded.audio, return audio_to_features(
decoded.sample_rate, decoded.audio,
clock=clock, decoded.sample_rate,
train_phase=train_phase, clock=clock,
augmentations=augmentations, train_phase=train_phase,
sample_id=wav_filename) augmentations=augmentations,
sample_id=wav_filename,
)
def entry_to_features(sample_id, audio, sample_rate, transcript, clock, train_phase=False, augmentations=None): def entry_to_features(
sample_id,
audio,
sample_rate,
transcript,
clock,
train_phase=False,
augmentations=None,
):
# https://bugs.python.org/issue32117 # https://bugs.python.org/issue32117
sparse_transcript = tf.SparseTensor(*transcript) sparse_transcript = tf.SparseTensor(*transcript)
features, features_len = audio_to_features(audio, features, features_len = audio_to_features(
sample_rate, audio,
transcript=sparse_transcript, sample_rate,
clock=clock, transcript=sparse_transcript,
train_phase=train_phase, clock=clock,
augmentations=augmentations, train_phase=train_phase,
sample_id=sample_id) augmentations=augmentations,
sample_id=sample_id,
)
return sample_id, features, features_len, sparse_transcript return sample_id, features, features_len, sparse_transcript
def to_sparse_tuple(sequence): def to_sparse_tuple(sequence):
r"""Creates a sparse representation of ``sequence``. r"""Creates a sparse representation of ``sequence``.
Returns a tuple with (indices, values, shape) Returns a tuple with (indices, values, shape)
""" """
indices = np.asarray(list(zip([0]*len(sequence), range(len(sequence)))), dtype=np.int64) indices = np.asarray(
list(zip([0] * len(sequence), range(len(sequence)))), dtype=np.int64
)
shape = np.asarray([1, len(sequence)], dtype=np.int64) shape = np.asarray([1, len(sequence)], dtype=np.int64)
return indices, sequence, shape return indices, sequence, shape
def create_dataset(sources, def create_dataset(
batch_size, sources,
epochs=1, batch_size,
augmentations=None, epochs=1,
cache_path=None, augmentations=None,
train_phase=False, cache_path=None,
reverse=False, train_phase=False,
limit=0, reverse=False,
exception_box=None, limit=0,
process_ahead=None, exception_box=None,
buffering=1 * MEGABYTE): process_ahead=None,
buffering=1 * MEGABYTE,
):
epoch_counter = Counter() # survives restarts of the dataset and its generator epoch_counter = Counter() # survives restarts of the dataset and its generator
def generate_values(): def generate_values():
epoch = epoch_counter['epoch'] epoch = epoch_counter["epoch"]
if train_phase: if train_phase:
epoch_counter['epoch'] += 1 epoch_counter["epoch"] += 1
samples = samples_from_sources(sources, buffering=buffering, labeled=True, reverse=reverse) samples = samples_from_sources(
sources, buffering=buffering, labeled=True, reverse=reverse
)
num_samples = len(samples) num_samples = len(samples)
if limit > 0: if limit > 0:
num_samples = min(limit, num_samples) num_samples = min(limit, num_samples)
samples = apply_sample_augmentations(samples, samples = apply_sample_augmentations(
augmentations, samples,
buffering=buffering, augmentations,
process_ahead=2 * batch_size if process_ahead is None else process_ahead, buffering=buffering,
clock=epoch / epochs, process_ahead=2 * batch_size if process_ahead is None else process_ahead,
final_clock=(epoch + 1) / epochs) clock=epoch / epochs,
final_clock=(epoch + 1) / epochs,
)
for sample_index, sample in enumerate(samples): for sample_index, sample in enumerate(samples):
if sample_index >= num_samples: if sample_index >= num_samples:
break break
clock = (epoch * num_samples + sample_index) / (epochs * num_samples) if train_phase and epochs > 0 else 0.0 clock = (
transcript = text_to_char_array(sample.transcript, Config.alphabet, context=sample.sample_id) (epoch * num_samples + sample_index) / (epochs * num_samples)
if train_phase and epochs > 0
else 0.0
)
transcript = text_to_char_array(
sample.transcript, Config.alphabet, context=sample.sample_id
)
transcript = to_sparse_tuple(transcript) transcript = to_sparse_tuple(transcript)
yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript, clock yield sample.sample_id, sample.audio, sample.audio_format.rate, transcript, clock
@ -128,31 +186,46 @@ def create_dataset(sources,
def batch_fn(sample_ids, features, features_len, transcripts): def batch_fn(sample_ids, features, features_len, transcripts):
features = tf.data.Dataset.zip((features, features_len)) features = tf.data.Dataset.zip((features, features_len))
features = features.padded_batch(batch_size, padded_shapes=([None, Config.n_input], [])) features = features.padded_batch(
batch_size, padded_shapes=([None, Config.n_input], [])
)
transcripts = transcripts.batch(batch_size).map(sparse_reshape) transcripts = transcripts.batch(batch_size).map(sparse_reshape)
sample_ids = sample_ids.batch(batch_size) sample_ids = sample_ids.batch(batch_size)
return tf.data.Dataset.zip((sample_ids, features, transcripts)) return tf.data.Dataset.zip((sample_ids, features, transcripts))
process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations) process_fn = partial(
entry_to_features, train_phase=train_phase, augmentations=augmentations
)
dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), dataset = tf.data.Dataset.from_generator(
output_types=(tf.string, tf.float32, tf.int32, remember_exception(generate_values, exception_box),
(tf.int64, tf.int32, tf.int64), tf.float64)) output_types=(
.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)) tf.string,
tf.float32,
tf.int32,
(tf.int64, tf.int32, tf.int64),
tf.float64,
),
).map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if cache_path: if cache_path:
dataset = dataset.cache(cache_path) dataset = dataset.cache(cache_path)
dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn) dataset = (
.prefetch(len(Config.available_devices))) dataset.window(batch_size, drop_remainder=train_phase)
.flat_map(batch_fn)
.prefetch(len(Config.available_devices))
)
return dataset return dataset
def split_audio_file(audio_path, def split_audio_file(
audio_format=DEFAULT_FORMAT, audio_path,
batch_size=1, audio_format=DEFAULT_FORMAT,
aggressiveness=3, batch_size=1,
outlier_duration_ms=10000, aggressiveness=3,
outlier_batch_size=1, outlier_duration_ms=10000,
exception_box=None): outlier_batch_size=1,
exception_box=None,
):
def generate_values(): def generate_values():
frames = read_frames_from_file(audio_path) frames = read_frames_from_file(audio_path)
segments = vad_split(frames, aggressiveness=aggressiveness) segments = vad_split(frames, aggressiveness=aggressiveness)
@ -166,17 +239,23 @@ def split_audio_file(audio_path,
return time_start, time_end, features, features_len return time_start, time_end, features, features_len
def create_batch_set(bs, criteria): def create_batch_set(bs, criteria):
return (tf.data.Dataset return (
.from_generator(remember_exception(generate_values, exception_box), tf.data.Dataset.from_generator(
output_types=(tf.int32, tf.int32, tf.float32)) remember_exception(generate_values, exception_box),
.map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE) output_types=(tf.int32, tf.int32, tf.float32),
.filter(criteria) )
.padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], []))) .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.filter(criteria)
.padded_batch(bs, padded_shapes=([], [], [None, Config.n_input], []))
)
nds = create_batch_set(batch_size, nds = create_batch_set(
lambda start, end, f, fl: end - start <= int(outlier_duration_ms)) batch_size, lambda start, end, f, fl: end - start <= int(outlier_duration_ms)
ods = create_batch_set(outlier_batch_size, )
lambda start, end, f, fl: end - start > int(outlier_duration_ms)) ods = create_batch_set(
outlier_batch_size,
lambda start, end, f, fl: end - start > int(outlier_duration_ms),
)
dataset = nds.concatenate(ods) dataset = nds.concatenate(ods)
dataset = dataset.prefetch(len(Config.available_devices)) dataset = dataset.prefetch(len(Config.available_devices))
return dataset return dataset

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import os import os
import absl.flags import absl.flags
FLAGS = absl.flags.FLAGS FLAGS = absl.flags.FLAGS
@ -12,179 +13,448 @@ def create_flags():
f = absl.flags f = absl.flags
f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.') f.DEFINE_string(
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run.') "train_files",
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested.') "",
f.DEFINE_string('metrics_files', '', 'comma separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed.') "comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.",
)
f.DEFINE_string(
"dev_files",
"",
"comma separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run.",
)
f.DEFINE_string(
"test_files",
"",
"comma separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested.",
)
f.DEFINE_string(
"metrics_files",
"",
"comma separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed.",
)
f.DEFINE_string('read_buffer', '1MB', 'buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)') f.DEFINE_string(
f.DEFINE_string('feature_cache', '', 'cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.') "read_buffer",
f.DEFINE_integer('cache_for_epochs', 0, 'after how many epochs the feature cache is invalidated again - 0 for "never"') "1MB",
"buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)",
)
f.DEFINE_string(
"feature_cache",
"",
"cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.",
)
f.DEFINE_integer(
"cache_for_epochs",
0,
'after how many epochs the feature cache is invalidated again - 0 for "never"',
)
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds') f.DEFINE_integer(
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds') "feature_win_len", 32, "feature extraction audio window length in milliseconds"
f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model') )
f.DEFINE_boolean('normalize_sample_rate', True, 'normalize sample rate of all train_files to --audio_sample_rate') f.DEFINE_integer(
"feature_win_step", 20, "feature extraction window step length in milliseconds"
)
f.DEFINE_integer("audio_sample_rate", 16000, "sample rate value expected by model")
f.DEFINE_boolean(
"normalize_sample_rate",
True,
"normalize sample rate of all train_files to --audio_sample_rate",
)
# Data Augmentation # Data Augmentation
# ================ # ================
f.DEFINE_multi_string('augment', None, 'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"') f.DEFINE_multi_string(
"augment",
None,
'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"',
)
# Global Constants # Global Constants
# ================ # ================
f.DEFINE_integer('epochs', 75, 'how many epochs (complete runs through the train files) to train for') f.DEFINE_integer(
"epochs",
75,
"how many epochs (complete runs through the train files) to train for",
)
f.DEFINE_float('dropout_rate', 0.05, 'dropout rate for feedforward layers') f.DEFINE_float("dropout_rate", 0.05, "dropout rate for feedforward layers")
f.DEFINE_float('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate') f.DEFINE_float(
f.DEFINE_float('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate') "dropout_rate2", -1.0, "dropout rate for layer 2 - defaults to dropout_rate"
f.DEFINE_float('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0') )
f.DEFINE_float('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0') f.DEFINE_float(
f.DEFINE_float('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate') "dropout_rate3", -1.0, "dropout rate for layer 3 - defaults to dropout_rate"
)
f.DEFINE_float("dropout_rate4", 0.0, "dropout rate for layer 4 - defaults to 0.0")
f.DEFINE_float("dropout_rate5", 0.0, "dropout rate for layer 5 - defaults to 0.0")
f.DEFINE_float(
"dropout_rate6", -1.0, "dropout rate for layer 6 - defaults to dropout_rate"
)
f.DEFINE_float('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers') f.DEFINE_float("relu_clip", 20.0, "ReLU clipping value for non-recurrent layers")
# Adam optimizer(http://arxiv.org/abs/1412.6980) parameters # Adam optimizer(http://arxiv.org/abs/1412.6980) parameters
f.DEFINE_float('beta1', 0.9, 'beta 1 parameter of Adam optimizer') f.DEFINE_float("beta1", 0.9, "beta 1 parameter of Adam optimizer")
f.DEFINE_float('beta2', 0.999, 'beta 2 parameter of Adam optimizer') f.DEFINE_float("beta2", 0.999, "beta 2 parameter of Adam optimizer")
f.DEFINE_float('epsilon', 1e-8, 'epsilon parameter of Adam optimizer') f.DEFINE_float("epsilon", 1e-8, "epsilon parameter of Adam optimizer")
f.DEFINE_float('learning_rate', 0.001, 'learning rate of Adam optimizer') f.DEFINE_float("learning_rate", 0.001, "learning rate of Adam optimizer")
# Batch sizes # Batch sizes
f.DEFINE_integer('train_batch_size', 1, 'number of elements in a training batch') f.DEFINE_integer("train_batch_size", 1, "number of elements in a training batch")
f.DEFINE_integer('dev_batch_size', 1, 'number of elements in a validation batch') f.DEFINE_integer("dev_batch_size", 1, "number of elements in a validation batch")
f.DEFINE_integer('test_batch_size', 1, 'number of elements in a test batch') f.DEFINE_integer("test_batch_size", 1, "number of elements in a test batch")
f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph') f.DEFINE_integer(
"export_batch_size", 1, "number of elements per batch on the exported graph"
)
# Performance # Performance
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') f.DEFINE_integer(
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') "inter_op_parallelism_threads",
f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory') 0,
f.DEFINE_boolean('load_cudnn', False, 'Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.') "number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED",
f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work') )
f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.') f.DEFINE_integer(
"intra_op_parallelism_threads",
0,
"number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED",
)
f.DEFINE_boolean(
"use_allow_growth",
False,
"use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory",
)
f.DEFINE_boolean(
"load_cudnn",
False,
"Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.",
)
f.DEFINE_boolean(
"train_cudnn",
False,
"use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work",
)
f.DEFINE_boolean(
"automatic_mixed_precision",
False,
"whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.",
)
# Sample limits # Sample limits
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit') f.DEFINE_integer(
f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set - 0 means no limit') "limit_train",
f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set - 0 means no limit') 0,
"maximum number of elements to use from train set - 0 means no limit",
)
f.DEFINE_integer(
"limit_dev",
0,
"maximum number of elements to use from validation set - 0 means no limit",
)
f.DEFINE_integer(
"limit_test",
0,
"maximum number of elements to use from test set - 0 means no limit",
)
# Sample order # Sample order
f.DEFINE_boolean('reverse_train', False, 'if to reverse sample order of the train set') f.DEFINE_boolean(
f.DEFINE_boolean('reverse_dev', False, 'if to reverse sample order of the dev set') "reverse_train", False, "if to reverse sample order of the train set"
f.DEFINE_boolean('reverse_test', False, 'if to reverse sample order of the test set') )
f.DEFINE_boolean("reverse_dev", False, "if to reverse sample order of the dev set")
f.DEFINE_boolean(
"reverse_test", False, "if to reverse sample order of the test set"
)
# Checkpointing # Checkpointing
f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') f.DEFINE_string(
f.DEFINE_string('load_checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') "checkpoint_dir",
f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') "",
f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds') 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification',
f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5') )
f.DEFINE_string('load_train', 'auto', 'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.') f.DEFINE_string(
f.DEFINE_string('load_evaluate', 'auto', 'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.') "load_checkpoint_dir",
"",
'directory in which checkpoints are stored - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification',
)
f.DEFINE_string(
"save_checkpoint_dir",
"",
'directory to which checkpoints are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification',
)
f.DEFINE_integer("checkpoint_secs", 600, "checkpoint saving interval in seconds")
f.DEFINE_integer(
"max_to_keep", 5, "number of checkpoint files to keep - default value is 5"
)
f.DEFINE_string(
"load_train",
"auto",
'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.',
)
f.DEFINE_string(
"load_evaluate",
"auto",
'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.',
)
# Transfer Learning # Transfer Learning
f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)') f.DEFINE_integer(
"drop_source_layers",
0,
"single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)",
)
# Exporting # Exporting
f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported') f.DEFINE_string(
f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models') "export_dir",
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine') "",
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency') "directory in which exported models are stored - if omitted, the model won't get exported",
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') )
f.DEFINE_string('export_file_name', 'output_graph', 'name for the exported model file name') f.DEFINE_boolean("remove_export", False, "whether to remove old exported models")
f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph') f.DEFINE_boolean("export_tflite", False, "export a graph ready for TF Lite engine")
f.DEFINE_integer(
"n_steps",
16,
"how many timesteps to process at once by the export graph, higher values mean more latency",
)
f.DEFINE_boolean(
"export_zip", False, "export a TFLite model and package with LM and info.json"
)
f.DEFINE_string(
"export_file_name", "output_graph", "name for the exported model file name"
)
f.DEFINE_integer(
"export_beam_width", 500, "default beam width to embed into exported graph"
)
# Model metadata # Model metadata
f.DEFINE_string('export_author_id', 'author', 'author of the exported model. GitHub user or organization name used to uniquely identify the author of this model') f.DEFINE_string(
f.DEFINE_string('export_model_name', 'model', 'name of the exported model. Must not contain forward slashes.') "export_author_id",
f.DEFINE_string('export_model_version', '0.0.1', 'semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with Coqui STT versions') "author",
"author of the exported model. GitHub user or organization name used to uniquely identify the author of this model",
)
f.DEFINE_string(
"export_model_name",
"model",
"name of the exported model. Must not contain forward slashes.",
)
f.DEFINE_string(
"export_model_version",
"0.0.1",
"semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with Coqui STT versions",
)
def str_val_equals_help(name, val_desc): def str_val_equals_help(name, val_desc):
f.DEFINE_string(name, '<{}>'.format(val_desc), val_desc) f.DEFINE_string(name, "<{}>".format(val_desc), val_desc)
str_val_equals_help('export_contact_info', 'public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors') str_val_equals_help(
str_val_equals_help('export_license', 'SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.') "export_contact_info",
str_val_equals_help('export_language', 'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".') "public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors",
str_val_equals_help('export_min_stt_version', 'minimum Coqui STT version (inclusive) the exported model is compatible with') )
str_val_equals_help('export_max_stt_version', 'maximum Coqui STT version (inclusive) the exported model is compatible with') str_val_equals_help(
str_val_equals_help('export_description', 'Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.') "export_license",
"SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.",
)
str_val_equals_help(
"export_language",
'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".',
)
str_val_equals_help(
"export_min_stt_version",
"minimum Coqui STT version (inclusive) the exported model is compatible with",
)
str_val_equals_help(
"export_max_stt_version",
"maximum Coqui STT version (inclusive) the exported model is compatible with",
)
str_val_equals_help(
"export_description",
"Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.",
)
# Reporting # Reporting
f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR') f.DEFINE_integer(
f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.') "log_level",
1,
"log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR",
)
f.DEFINE_boolean(
"show_progressbar",
True,
"Show progress for training, validation and testing processes. Log level should be > 0.",
)
f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console') f.DEFINE_boolean(
f.DEFINE_integer('report_count', 5, 'number of phrases for each of best WER, median WER and worst WER to print out during a WER report') "log_placement",
False,
"whether to log device placement of the operators to the console",
)
f.DEFINE_integer(
"report_count",
5,
"number of phrases for each of best WER, median WER and worst WER to print out during a WER report",
)
f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "stt/summaries" within user\'s data home specified by the XDG Base Directory Specification') f.DEFINE_string(
"summary_dir",
"",
'target directory for TensorBoard summaries - defaults to directory "stt/summaries" within user\'s data home specified by the XDG Base Directory Specification',
)
f.DEFINE_string('test_output_file', '', 'path to a file to save all src/decoded/distance/loss tuples generated during a test epoch') f.DEFINE_string(
"test_output_file",
"",
"path to a file to save all src/decoded/distance/loss tuples generated during a test epoch",
)
# Geometry # Geometry
f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers') f.DEFINE_integer("n_hidden", 2048, "layer width to use when initialising layers")
f.DEFINE_boolean('layer_norm', False, 'wether to use layer-normalization after each fully-connected layer (except the last one)') f.DEFINE_boolean(
"layer_norm",
False,
"wether to use layer-normalization after each fully-connected layer (except the last one)",
)
# Initialization # Initialization
f.DEFINE_integer('random_seed', 4568, 'default random seed that is used to initialize variables') f.DEFINE_integer(
"random_seed", 4568, "default random seed that is used to initialize variables"
)
# Early Stopping # Early Stopping
f.DEFINE_boolean('early_stop', False, 'Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.') f.DEFINE_boolean(
f.DEFINE_integer('es_epochs', 25, 'Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point') "early_stop",
f.DEFINE_float('es_min_delta', 0.05, 'Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau') False,
"Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.",
)
f.DEFINE_integer(
"es_epochs",
25,
"Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point",
)
f.DEFINE_float(
"es_min_delta",
0.05,
"Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau",
)
# Reduce learning rate on plateau # Reduce learning rate on plateau
f.DEFINE_boolean('reduce_lr_on_plateau', False, 'Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.') f.DEFINE_boolean(
f.DEFINE_integer('plateau_epochs', 10, 'Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping') "reduce_lr_on_plateau",
f.DEFINE_float('plateau_reduction', 0.1, 'Multiplicative factor to apply to the current learning rate if a plateau has occurred.') False,
f.DEFINE_boolean('force_initialize_learning_rate', False, 'Force re-initialization of learning rate which was previously reduced.') "Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.",
)
f.DEFINE_integer(
"plateau_epochs",
10,
"Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping",
)
f.DEFINE_float(
"plateau_reduction",
0.1,
"Multiplicative factor to apply to the current learning rate if a plateau has occurred.",
)
f.DEFINE_boolean(
"force_initialize_learning_rate",
False,
"Force re-initialization of learning rate which was previously reduced.",
)
# Decoder # Decoder
f.DEFINE_boolean('bytes_output_mode', False, 'enable Bytes Output Mode mode. When this is used the model outputs UTF-8 byte values directly rather than using an alphabet mapping. The --alphabet_config_path option will be ignored. See the training documentation for more details.') f.DEFINE_boolean(
f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.') "bytes_output_mode",
f.DEFINE_string('scorer_path', '', 'path to the external scorer file.') False,
f.DEFINE_alias('scorer', 'scorer_path') "enable Bytes Output Mode mode. When this is used the model outputs UTF-8 byte values directly rather than using an alphabet mapping. The --alphabet_config_path option will be ignored. See the training documentation for more details.",
f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') )
f.DEFINE_float('lm_alpha', 0.931289039105002, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') f.DEFINE_string(
f.DEFINE_float('lm_beta', 1.1834137581510284, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') "alphabet_config_path",
f.DEFINE_float('cutoff_prob', 1.0, 'only consider characters until this probability mass is reached. 1.0 = disabled.') "data/alphabet.txt",
f.DEFINE_integer('cutoff_top_n', 300, 'only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.') "path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.",
)
f.DEFINE_string("scorer_path", "", "path to the external scorer file.")
f.DEFINE_alias("scorer", "scorer_path")
f.DEFINE_integer(
"beam_width",
1024,
"beam width used in the CTC decoder when building candidate transcriptions",
)
f.DEFINE_float(
"lm_alpha",
0.931289039105002,
"the alpha hyperparameter of the CTC decoder. Language Model weight.",
)
f.DEFINE_float(
"lm_beta",
1.1834137581510284,
"the beta hyperparameter of the CTC decoder. Word insertion weight.",
)
f.DEFINE_float(
"cutoff_prob",
1.0,
"only consider characters until this probability mass is reached. 1.0 = disabled.",
)
f.DEFINE_integer(
"cutoff_top_n",
300,
"only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.",
)
# Inference mode # Inference mode
f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.') f.DEFINE_string(
"one_shot_infer",
"",
"one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.",
)
# Optimizer mode # Optimizer mode
f.DEFINE_float('lm_alpha_max', 5, 'the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.') f.DEFINE_float(
f.DEFINE_float('lm_beta_max', 5, 'the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.') "lm_alpha_max",
f.DEFINE_integer('n_trials', 2400, 'the number of trials to run during hyperparameter optimization.') 5,
"the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.",
)
f.DEFINE_float(
"lm_beta_max",
5,
"the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.",
)
f.DEFINE_integer(
"n_trials",
2400,
"the number of trials to run during hyperparameter optimization.",
)
# Register validators for paths which require a file to be specified # Register validators for paths which require a file to be specified
f.register_validator('alphabet_config_path', f.register_validator(
os.path.isfile, "alphabet_config_path",
message='The file pointed to by --alphabet_config_path must exist and be readable.') os.path.isfile,
message="The file pointed to by --alphabet_config_path must exist and be readable.",
)
f.register_validator(
"one_shot_infer",
lambda value: not value or os.path.isfile(value),
message="The file pointed to by --one_shot_infer must exist and be readable.",
)
f.register_validator('one_shot_infer',
lambda value: not value or os.path.isfile(value),
message='The file pointed to by --one_shot_infer must exist and be readable.')
# sphinx-doc: training_ref_flags_end # sphinx-doc: training_ref_flags_end

View File

@ -6,4 +6,4 @@ def get_available_gpus(config):
Returns the number of GPUs available on this system. Returns the number of GPUs available on this system.
""" """
local_device_protos = device_lib.list_local_devices(session_config=config) local_device_protos = device_lib.list_local_devices(session_config=config)
return [x.name for x in local_device_protos if x.device_type == 'GPU'] return [x.name for x in local_device_protos if x.device_type == "GPU"]

View File

@ -1,21 +1,21 @@
import heapq
import os import os
import random
import sys import sys
import time import time
import heapq
import semver
import random
from multiprocessing import Pool
from collections import namedtuple from collections import namedtuple
from multiprocessing import Pool
import semver
KILO = 1024 KILO = 1024
KILOBYTE = 1 * KILO KILOBYTE = 1 * KILO
MEGABYTE = KILO * KILOBYTE MEGABYTE = KILO * KILOBYTE
GIGABYTE = KILO * MEGABYTE GIGABYTE = KILO * MEGABYTE
TERABYTE = KILO * GIGABYTE TERABYTE = KILO * GIGABYTE
SIZE_PREFIX_LOOKUP = {'k': KILOBYTE, 'm': MEGABYTE, 'g': GIGABYTE, 't': TERABYTE} SIZE_PREFIX_LOOKUP = {"k": KILOBYTE, "m": MEGABYTE, "g": GIGABYTE, "t": TERABYTE}
ValueRange = namedtuple('ValueRange', 'start end r') ValueRange = namedtuple("ValueRange", "start end r")
def parse_file_size(file_size): def parse_file_size(file_size):
@ -23,39 +23,49 @@ def parse_file_size(file_size):
if len(file_size) == 0: if len(file_size) == 0:
return 0 return 0
n = int(keep_only_digits(file_size)) n = int(keep_only_digits(file_size))
if file_size[-1] == 'b': if file_size[-1] == "b":
file_size = file_size[:-1] file_size = file_size[:-1]
e = file_size[-1] e = file_size[-1]
return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
def keep_only_digits(txt): def keep_only_digits(txt):
return ''.join(filter(str.isdigit, txt)) return "".join(filter(str.isdigit, txt))
def secs_to_hours(secs): def secs_to_hours(secs):
hours, remainder = divmod(secs, 3600) hours, remainder = divmod(secs, 3600)
minutes, seconds = divmod(remainder, 60) minutes, seconds = divmod(remainder, 60)
return '%d:%02d:%02d' % (hours, minutes, seconds) return "%d:%02d:%02d" % (hours, minutes, seconds)
def check_ctcdecoder_version():
    """Verify that the installed CTC decoder matches this package's version.

    Reads the version string from the ``VERSION`` file located one directory
    above this module and compares it with
    ``coqui_stt_ctcdecoder.__version__`` using ``semver.compare``.

    Returns:
        The result of ``semver.compare``; since any non-zero result exits the
        process, the value returned is 0.

    Raises:
        ImportError: if importing from ``coqui_stt_ctcdecoder`` fails and the
            error message does not mention ``__version__``.

    Prints an explanation and calls ``sys.exit(1)`` when the decoder does not
    expose ``__version__`` or when the two versions differ.
    """
    version_path = os.path.join(os.path.dirname(__file__), "../VERSION")
    # Read through a context manager so the file handle is closed right away
    # instead of being left open until garbage collection.
    with open(version_path) as version_file:
        ds_version_s = version_file.read().strip()

    try:
        # pylint: disable=import-outside-toplevel
        from coqui_stt_ctcdecoder import __version__ as decoder_version
    except ImportError as e:
        # An import error mentioning __version__ is reported as a decoder
        # package that needs upgrading; anything else is re-raised untouched.
        if e.msg.find("__version__") > 0:
            print(
                "Coqui STT version ({ds_version}) requires CTC decoder to expose __version__. "
                "Please upgrade the coqui_stt_ctcdecoder package to version {ds_version}".format(
                    ds_version=ds_version_s
                )
            )
            sys.exit(1)
        raise

    rv = semver.compare(ds_version_s, decoder_version)
    if rv != 0:
        print(
            "Coqui STT version ({}) and CTC decoder version ({}) do not match. "
            "Please ensure matching versions are in use.".format(
                ds_version_s, decoder_version
            )
        )
        sys.exit(1)

    return rv
@ -65,6 +75,7 @@ class Interleaved:
"""Collection that lazily combines sorted collections in an interleaving fashion. """Collection that lazily combines sorted collections in an interleaving fashion.
During iteration the next smallest element from all the sorted collections is always picked. During iteration the next smallest element from all the sorted collections is always picked.
The collections must support iter() and len().""" The collections must support iter() and len()."""
def __init__(self, *iterables, key=lambda obj: obj, reverse=False): def __init__(self, *iterables, key=lambda obj: obj, reverse=False):
self.iterables = iterables self.iterables = iterables
self.key = key self.key = key
@ -83,6 +94,7 @@ class LenMap:
Wrapper around python map() output object that preserves the original collection length Wrapper around python map() output object that preserves the original collection length
by implementing __len__. by implementing __len__.
""" """
def __init__(self, fn, iterable): def __init__(self, fn, iterable):
try: try:
self.length = len(iterable) self.length = len(iterable)
@ -108,11 +120,21 @@ class LimitingPool:
"""Limits unbound ahead-processing of multiprocessing.Pool's imap method """Limits unbound ahead-processing of multiprocessing.Pool's imap method
before items get consumed by the iteration caller. before items get consumed by the iteration caller.
This prevents OOM issues in situations where items represent larger memory allocations.""" This prevents OOM issues in situations where items represent larger memory allocations."""
def __init__(self, processes=None, initializer=None, initargs=None, process_ahead=None, sleeping_for=0.1):
def __init__(
self,
processes=None,
initializer=None,
initargs=None,
process_ahead=None,
sleeping_for=0.1,
):
self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead self.process_ahead = os.cpu_count() if process_ahead is None else process_ahead
self.sleeping_for = sleeping_for self.sleeping_for = sleeping_for
self.processed = 0 self.processed = 0
self.pool = Pool(processes=processes, initializer=initializer, initargs=initargs) self.pool = Pool(
processes=processes, initializer=initializer, initargs=initargs
)
def __enter__(self): def __enter__(self):
return self return self
@ -139,6 +161,7 @@ class LimitingPool:
class ExceptionBox: class ExceptionBox:
"""Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator. """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
Used in conjunction with `remember_exception`.""" Used in conjunction with `remember_exception`."""
def __init__(self): def __init__(self):
self.exception = None self.exception = None
@ -152,6 +175,7 @@ class ExceptionBox:
def remember_exception(iterable, exception_box=None): def remember_exception(iterable, exception_box=None):
"""Wraps a TensorFlow dataset generator for catching its actual exceptions """Wraps a TensorFlow dataset generator for catching its actual exceptions
that would otherwise just interrupt iteration w/o bubbling up.""" that would otherwise just interrupt iteration w/o bubbling up."""
def do_iterate(): def do_iterate():
try: try:
yield from iterable() yield from iterable()
@ -159,6 +183,7 @@ def remember_exception(iterable, exception_box=None):
return return
except Exception as ex: # pylint: disable = broad-except except Exception as ex: # pylint: disable = broad-except
exception_box.exception = ex exception_box.exception = ex
return iterable if exception_box is None else do_iterate return iterable if exception_box is None else do_iterate
@ -174,30 +199,34 @@ def get_value_range(value, target_type):
Any "missing" values are filled so that ValueRange always includes [start,end,r]. Any "missing" values are filled so that ValueRange always includes [start,end,r].
""" """
if isinstance(value, str): if isinstance(value, str):
if '~' in value: if "~" in value:
parts = value.split('~') parts = value.split("~")
if len(parts) != 2: if len(parts) != 2:
raise ValueError('Cannot parse value range') raise ValueError("Cannot parse value range")
value = parts[0] value = parts[0]
r = parts[1] r = parts[1]
else: else:
r = 0 # if no <r> supplied, use 0 r = 0 # if no <r> supplied, use 0
parts = value.split(':') parts = value.split(":")
if len(parts) == 1: if len(parts) == 1:
parts.append(parts[0]) # only one <value> given, so double it parts.append(parts[0]) # only one <value> given, so double it
if len(parts) != 2: if len(parts) != 2:
raise ValueError('Cannot parse value range') raise ValueError("Cannot parse value range")
return ValueRange(target_type(parts[0]), target_type(parts[1]), target_type(r)) return ValueRange(target_type(parts[0]), target_type(parts[1]), target_type(r))
if isinstance(value, tuple): if isinstance(value, tuple):
if len(value) == 2: if len(value) == 2:
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(0)) return ValueRange(
target_type(value[0]), target_type(value[1]), target_type(0)
)
if len(value) == 3: if len(value) == 3:
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2])) return ValueRange(
target_type(value[0]), target_type(value[1]), target_type(value[2])
)
else: else:
raise ValueError('Cannot convert to ValueRange: Wrong tuple size') raise ValueError("Cannot convert to ValueRange: Wrong tuple size")
if isinstance(value, int) or isinstance(value, float): if isinstance(value, int) or isinstance(value, float):
return ValueRange(target_type(value), target_type(value), target_type(0)) return ValueRange(target_type(value), target_type(value), target_type(0))
raise ValueError('Cannot convert to ValueRange: Wrong tuple size') raise ValueError("Cannot convert to ValueRange: Wrong tuple size")
def int_range(value): def int_range(value):
@ -217,20 +246,25 @@ def pick_value_from_range(value_range, clock=None):
def tf_pick_value_from_range(value_range, clock=None, double_precision=False):
    """Build a TensorFlow value picked from ``value_range``.

    The value is interpolated between ``value_range.start`` and
    ``value_range.end`` by ``clock``, which is clamped to ``[0.0, 1.0]``.
    When ``clock`` is None it is drawn from ``tf.random.stateless_uniform``
    with the fixed seed ``(-1, 1)``.  If ``value_range.r`` is truthy, the
    result is re-sampled uniformly within ``r`` of the interpolated value.

    Returns:
        The value cast to ``tf.int32``/``tf.int64`` (after rounding) when
        ``value_range.start`` is a Python int, otherwise to
        ``tf.float32``/``tf.float64``; the wider type is used when
        ``double_precision`` is true.
    """
    import tensorflow as tf  # pylint: disable=import-outside-toplevel

    start, end, radius = value_range.start, value_range.end, value_range.r

    if clock is not None:
        lower = tf.constant(0.0, dtype=tf.float64)
        upper = tf.constant(1.0, dtype=tf.float64)
        clock = tf.maximum(lower, tf.minimum(upper, clock))
    else:
        clock = tf.random.stateless_uniform([], seed=(-1, 1), dtype=tf.float64)

    picked = start + clock * (end - start)
    if radius:
        # if the option <r> (<value>~<r>, randomization radius) is supplied,
        # sample the value from a uniform distribution with "radius" <r>
        lowest = picked - radius
        highest = picked + radius
        seed = (clock * tf.int32.min, clock * tf.int32.max)
        picked = tf.random.stateless_uniform(
            [], minval=lowest, maxval=highest, seed=seed, dtype=tf.float64
        )

    if isinstance(start, int):
        int_type = tf.int64 if double_precision else tf.int32
        return tf.cast(tf.math.round(picked), int_type)
    float_type = tf.float64 if double_precision else tf.float32
    return tf.cast(picked, float_type)

View File

@ -3,33 +3,72 @@ import importlib
import os import os
import re import re
import sys import sys
from .helpers import secs_to_hours
from collections import Counter from collections import Counter
from .helpers import secs_to_hours
def get_counter():
    """Return a new ``Counter`` with every import statistic initialised to 0."""
    stat_names = (
        "all",
        "failed",
        "invalid_label",
        "too_short",
        "too_long",
        "imported_time",
        "total_time",
    )
    return Counter(dict.fromkeys(stat_names, 0))
def get_imported_samples(counter):
    """Return ``counter["all"]`` minus every category of skipped samples."""
    imported = counter["all"]
    for reason in ("failed", "too_short", "too_long", "invalid_label"):
        imported = imported - counter[reason]
    return imported
def print_import_report(counter, sample_rate, max_secs):
    """Print a summary of an import run.

    Reports the number of imported samples, one line per non-empty skip
    category, and the imported versus total audio duration (the time
    counters divided by ``sample_rate``, formatted by ``secs_to_hours``).
    """
    print("Imported %d samples." % (get_imported_samples(counter)))

    # (counter key, message template, extra format arguments), in print order.
    skip_reports = (
        ("failed", "Skipped %d samples that failed upon conversion.", ()),
        (
            "invalid_label",
            "Skipped %d samples that failed on transcript validation.",
            (),
        ),
        (
            "too_short",
            "Skipped %d samples that were too short to match the transcript.",
            (),
        ),
        (
            "too_long",
            "Skipped %d samples that were longer than %d seconds.",
            (max_secs,),
        ),
    )
    for key, template, extra_args in skip_reports:
        if counter[key] > 0:
            print(template % ((counter[key],) + extra_args))

    imported_time = secs_to_hours(counter["imported_time"] / sample_rate)
    total_time = secs_to_hours(counter["total_time"] / sample_rate)
    print("Final amount of imported audio: %s from %s." % (imported_time, total_time))
def get_importers_parser(description):
    """Create an ``ArgumentParser`` that accepts ``--validate_label_locale``."""
    locale_help = (
        "Path to a Python file defining a |validate_label| function for your locale. "
        "WARNING: THIS WILL ADD THIS FILE's DIRECTORY INTO PYTHONPATH."
    )
    importers_parser = argparse.ArgumentParser(description=description)
    importers_parser.add_argument("--validate_label_locale", help=locale_help)
    return importers_parser
def get_validate_label(args): def get_validate_label(args):
""" """
Expects an argparse.Namespace argument to search for validate_label_locale parameter. Expects an argparse.Namespace argument to search for validate_label_locale parameter.
@ -43,19 +82,22 @@ def get_validate_label(args):
:type: function :type: function
""" """
# Python 3.5 does not support passing a pathlib.Path to os.path.* methods # Python 3.5 does not support passing a pathlib.Path to os.path.* methods
if 'validate_label_locale' not in args or (args.validate_label_locale is None): if "validate_label_locale" not in args or (args.validate_label_locale is None):
print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.') print(
"WARNING: No --validate_label_locale specified, your might end with inconsistent dataset."
)
return validate_label_eng return validate_label_eng
validate_label_locale = str(args.validate_label_locale) validate_label_locale = str(args.validate_label_locale)
if not os.path.exists(os.path.abspath(validate_label_locale)): if not os.path.exists(os.path.abspath(validate_label_locale)):
print('ERROR: Inexistent --validate_label_locale specified. Please check.') print("ERROR: Inexistent --validate_label_locale specified. Please check.")
return None return None
module_dir = os.path.abspath(os.path.dirname(validate_label_locale)) module_dir = os.path.abspath(os.path.dirname(validate_label_locale))
sys.path.insert(1, module_dir) sys.path.insert(1, module_dir)
fname = os.path.basename(validate_label_locale).replace('.py', '') fname = os.path.basename(validate_label_locale).replace(".py", "")
locale_module = importlib.import_module(fname, package=None) locale_module = importlib.import_module(fname, package=None)
return locale_module.validate_label return locale_module.validate_label
# Validate and normalize transcriptions. Returns a cleaned version of the label # Validate and normalize transcriptions. Returns a cleaned version of the label
# or None if it's invalid. # or None if it's invalid.
def validate_label_eng(label): def validate_label_eng(label):
@ -72,7 +114,7 @@ def validate_label_eng(label):
label = label.replace("?", "") label = label.replace("?", "")
label = label.replace("!", "") label = label.replace("!", "")
label = label.replace(":", "") label = label.replace(":", "")
label = label.replace("\"", "") label = label.replace('"', "")
label = label.strip() label = label.strip()
label = label.lower() label = label.lower()

View File

@ -4,6 +4,7 @@ into HDFS storage using Tensorflow's C++ FileStream API.
Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets. Currently only includes wrappers for Google's GCS, but this can easily be expanded for AWS S3 buckets.
""" """
import os import os
from tensorflow.io import gfile from tensorflow.io import gfile
@ -12,7 +13,7 @@ def is_remote_path(path):
Returns True iff the path is one of the remote formats that this Returns True iff the path is one of the remote formats that this
module supports module supports
""" """
return path.startswith('gs://') or path.startswith('hdfs://') return path.startswith("gs://") or path.startswith("hdfs://")
def path_exists_remote(path): def path_exists_remote(path):
@ -32,7 +33,9 @@ def copy_remote(src, dst, overwrite=False):
return gfile.copy(src, dst, overwrite) return gfile.copy(src, dst, overwrite)
def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, closefd=True, opener=None): def open_remote(
path, mode="r", buffering=-1, encoding=None, newline=None, closefd=True, opener=None
):
""" """
Wrapper around open() method that can handle remote paths like `gs://...` Wrapper around open() method that can handle remote paths like `gs://...`
off Google Cloud using Tensorflow's IO helpers. off Google Cloud using Tensorflow's IO helpers.
@ -45,7 +48,15 @@ def open_remote(path, mode='r', buffering=-1, encoding=None, newline=None, close
""" """
if is_remote_path(path): if is_remote_path(path):
return gfile.GFile(path, mode=mode) return gfile.GFile(path, mode=mode)
return open(path, mode, buffering=buffering, encoding=encoding, newline=newline, closefd=closefd, opener=opener) return open(
path,
mode,
buffering=buffering,
encoding=encoding,
newline=newline,
closefd=closefd,
opener=opener,
)
def isdir_remote(path): def isdir_remote(path):

View File

@ -1,42 +1,43 @@
from __future__ import print_function from __future__ import print_function
import progressbar
import sys import sys
from .flags import FLAGS import progressbar
from .flags import FLAGS
# Logging functions # Logging functions
# ================= # =================
def prefix_print(prefix, message):
    """Print ``message`` with ``prefix`` prepended to each of its lines."""
    prefixed_lines = [prefix + line for line in message.split("\n")]
    print("\n".join(prefixed_lines))
def log_debug(message):
    """Print ``message`` with a ``D `` prefix when ``FLAGS.log_level`` is 0."""
    if not FLAGS.log_level == 0:
        return
    prefix_print("D ", message)
def log_info(message):
    """Print ``message`` with an ``I `` prefix when ``FLAGS.log_level`` is at most 1."""
    if not FLAGS.log_level <= 1:
        return
    prefix_print("I ", message)
def log_warn(message):
    """Print ``message`` with a ``W `` prefix when ``FLAGS.log_level`` is at most 2."""
    if not FLAGS.log_level <= 2:
        return
    prefix_print("W ", message)
def log_error(message):
    """Print ``message`` with an ``E `` prefix when ``FLAGS.log_level`` is at most 3."""
    if not FLAGS.log_level <= 3:
        return
    prefix_print("E ", message)
def create_progressbar(*args, **kwargs): def create_progressbar(*args, **kwargs):
# Progress bars in stdout by default # Progress bars in stdout by default
if 'fd' not in kwargs: if "fd" not in kwargs:
kwargs['fd'] = sys.stdout kwargs["fd"] = sys.stdout
if FLAGS.show_progressbar: if FLAGS.show_progressbar:
return progressbar.ProgressBar(*args, **kwargs) return progressbar.ProgressBar(*args, **kwargs)

View File

@ -1,45 +1,47 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os
import io
import csv import csv
import io
import json import json
import os
import tarfile import tarfile
from pathlib import Path
from functools import partial from functools import partial
from pathlib import Path
from .helpers import KILOBYTE, MEGABYTE, GIGABYTE, Interleaved, LenMap
from .audio import ( from .audio import (
Sample,
AUDIO_TYPE_PCM,
AUDIO_TYPE_OPUS, AUDIO_TYPE_OPUS,
AUDIO_TYPE_PCM,
SERIALIZABLE_AUDIO_TYPES, SERIALIZABLE_AUDIO_TYPES,
Sample,
get_loadable_audio_type_from_extension, get_loadable_audio_type_from_extension,
write_wav write_wav,
) )
from .io import open_remote, is_remote_path from .helpers import GIGABYTE, KILOBYTE, MEGABYTE, Interleaved, LenMap
from .io import is_remote_path, open_remote
# Byte order and widths used when serializing integers.
BIG_ENDIAN = "big"
INT_SIZE = 4
BIGINT_SIZE = INT_SIZE * 2
# Byte signature expected at the very start of an SDB file.
MAGIC = b"SAMPLEDB"

# Buffer sizes: default file buffering, and buffering used for reverse reads.
BUFFER_SIZE = MEGABYTE * 1
REVERSE_BUFFER_SIZE = KILOBYTE * 16
CACHE_SIZE = GIGABYTE * 1

# Keys and values used in the SDB meta-data schema.
SCHEMA_KEY = "schema"
CONTENT_KEY = "content"
MIME_TYPE_KEY = "mime-type"
MIME_TYPE_TEXT = "text/plain"
CONTENT_TYPE_SPEECH = "speech"
CONTENT_TYPE_TRANSCRIPT = "transcript"
class LabeledSample(Sample): class LabeledSample(Sample):
"""In-memory labeled audio sample representing an utterance. """In-memory labeled audio sample representing an utterance.
Derived from util.audio.Sample and used by sample collection readers and writers.""" Derived from util.audio.Sample and used by sample collection readers and writers."""
def __init__(self, audio_type, raw_data, transcript, audio_format=None, sample_id=None):
def __init__(
self, audio_type, raw_data, transcript, audio_format=None, sample_id=None
):
""" """
Parameters Parameters
---------- ----------
@ -55,7 +57,9 @@ class LabeledSample(Sample):
Tracking ID - should indicate sample's origin as precisely as possible. Tracking ID - should indicate sample's origin as precisely as possible.
It is typically assigned by collection readers. It is typically assigned by collection readers.
""" """
super().__init__(audio_type, raw_data, audio_format=audio_format, sample_id=sample_id) super().__init__(
audio_type, raw_data, audio_format=audio_format, sample_id=sample_id
)
self.transcript = transcript self.transcript = transcript
@ -65,13 +69,14 @@ class PackedSample:
have the child process do the loading/unpacking of the sample, allowing for parallel file have the child process do the loading/unpacking of the sample, allowing for parallel file
I/O. I/O.
""" """
def __init__(self, filename, audio_type, label): def __init__(self, filename, audio_type, label):
self.filename = filename self.filename = filename
self.audio_type = audio_type self.audio_type = audio_type
self.label = label self.label = label
def unpack(self): def unpack(self):
with open_remote(self.filename, 'rb') as audio_file: with open_remote(self.filename, "rb") as audio_file:
data = audio_file.read() data = audio_file.read()
if self.label is None: if self.label is None:
s = Sample(self.audio_type, data, sample_id=self.filename) s = Sample(self.audio_type, data, sample_id=self.filename)
@ -83,7 +88,7 @@ def unpack_maybe(sample):
""" """
Loads the supplied sample from disk (or the network) if the audio isn't loaded in to memory already. Loads the supplied sample from disk (or the network) if the audio isn't loaded in to memory already.
""" """
if hasattr(sample, 'unpack'): if hasattr(sample, "unpack"):
realized_sample = sample.unpack() realized_sample = sample.unpack()
else: else:
realized_sample = sample realized_sample = sample
@ -117,13 +122,16 @@ def load_sample(filename, label=None):
class DirectSDBWriter: class DirectSDBWriter:
"""Sample collection writer for creating a Sample DB (SDB) file""" """Sample collection writer for creating a Sample DB (SDB) file"""
def __init__(self,
sdb_filename, def __init__(
buffering=BUFFER_SIZE, self,
audio_type=AUDIO_TYPE_OPUS, sdb_filename,
bitrate=None, buffering=BUFFER_SIZE,
id_prefix=None, audio_type=AUDIO_TYPE_OPUS,
labeled=True): bitrate=None,
id_prefix=None,
labeled=True,
):
""" """
Parameters Parameters
---------- ----------
@ -148,7 +156,7 @@ class DirectSDBWriter:
raise ValueError('Audio type "{}" not supported'.format(audio_type)) raise ValueError('Audio type "{}" not supported'.format(audio_type))
self.audio_type = audio_type self.audio_type = audio_type
self.bitrate = bitrate self.bitrate = bitrate
self.sdb_file = open_remote(sdb_filename, 'wb', buffering=buffering) self.sdb_file = open_remote(sdb_filename, "wb", buffering=buffering)
self.offsets = [] self.offsets = []
self.num_samples = 0 self.num_samples = 0
@ -156,7 +164,9 @@ class DirectSDBWriter:
schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}] schema_entries = [{CONTENT_KEY: CONTENT_TYPE_SPEECH, MIME_TYPE_KEY: audio_type}]
if self.labeled: if self.labeled:
schema_entries.append({CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}) schema_entries.append(
{CONTENT_KEY: CONTENT_TYPE_TRANSCRIPT, MIME_TYPE_KEY: MIME_TYPE_TEXT}
)
meta_data = {SCHEMA_KEY: schema_entries} meta_data = {SCHEMA_KEY: schema_entries}
meta_data = json.dumps(meta_data).encode() meta_data = json.dumps(meta_data).encode()
self.write_big_int(len(meta_data)) self.write_big_int(len(meta_data))
@ -177,20 +187,23 @@ class DirectSDBWriter:
def add(self, sample):
    """Append ``sample`` to the SDB file and return its newly assigned ID.

    The sample is converted to the writer's audio type, then written as one
    entry: the entry length followed by length-prefixed columns (audio, plus
    the encoded transcript when the writer is labeled).  The entry's file
    offset is recorded and the sample's ``sample_id`` is set to
    ``"<id_prefix>:<index>"``.
    """

    def to_bytes(n):
        return n.to_bytes(INT_SIZE, BIG_ENDIAN)

    sample.change_audio_type(self.audio_type, bitrate=self.bitrate)
    audio = sample.audio.getbuffer()
    # Each column is stored as its length followed by its payload.
    columns = [to_bytes(len(audio)), audio]
    if self.labeled:
        transcript = sample.transcript.encode()
        columns += [to_bytes(len(transcript)), transcript]
    entry_len = to_bytes(sum(len(part) for part in columns))
    entry = b"".join([entry_len] + columns)

    self.offsets.append(self.sdb_file.tell())
    self.sdb_file.write(entry)
    sample.sample_id = "{}:{}".format(self.id_prefix, self.num_samples)
    self.num_samples += 1
    return sample.sample_id
@ -221,12 +234,15 @@ class DirectSDBWriter:
class SDB: # pylint: disable=too-many-instance-attributes class SDB: # pylint: disable=too-many-instance-attributes
"""Sample collection reader for reading a Sample DB (SDB) file""" """Sample collection reader for reading a Sample DB (SDB) file"""
def __init__(self,
sdb_filename, def __init__(
buffering=BUFFER_SIZE, self,
id_prefix=None, sdb_filename,
labeled=True, buffering=BUFFER_SIZE,
reverse=False): id_prefix=None,
labeled=True,
reverse=False,
):
""" """
Parameters Parameters
---------- ----------
@ -244,30 +260,36 @@ class SDB: # pylint: disable=too-many-instance-attributes
""" """
self.sdb_filename = sdb_filename self.sdb_filename = sdb_filename
self.id_prefix = sdb_filename if id_prefix is None else id_prefix self.id_prefix = sdb_filename if id_prefix is None else id_prefix
self.sdb_file = open_remote(sdb_filename, 'rb', buffering=REVERSE_BUFFER_SIZE if reverse else buffering) self.sdb_file = open_remote(
sdb_filename, "rb", buffering=REVERSE_BUFFER_SIZE if reverse else buffering
)
self.offsets = [] self.offsets = []
if self.sdb_file.read(len(MAGIC)) != MAGIC: if self.sdb_file.read(len(MAGIC)) != MAGIC:
raise RuntimeError('No Sample Database') raise RuntimeError("No Sample Database")
meta_chunk_len = self.read_big_int() meta_chunk_len = self.read_big_int()
self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode()) self.meta = json.loads(self.sdb_file.read(meta_chunk_len).decode())
if SCHEMA_KEY not in self.meta: if SCHEMA_KEY not in self.meta:
raise RuntimeError('Missing schema') raise RuntimeError("Missing schema")
self.schema = self.meta[SCHEMA_KEY] self.schema = self.meta[SCHEMA_KEY]
speech_columns = self.find_columns(content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES) speech_columns = self.find_columns(
content=CONTENT_TYPE_SPEECH, mime_type=SERIALIZABLE_AUDIO_TYPES
)
if not speech_columns: if not speech_columns:
raise RuntimeError('No speech data (missing in schema)') raise RuntimeError("No speech data (missing in schema)")
self.speech_index = speech_columns[0] self.speech_index = speech_columns[0]
self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY] self.audio_type = self.schema[self.speech_index][MIME_TYPE_KEY]
self.transcript_index = None self.transcript_index = None
if labeled is not False: if labeled is not False:
transcript_columns = self.find_columns(content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT) transcript_columns = self.find_columns(
content=CONTENT_TYPE_TRANSCRIPT, mime_type=MIME_TYPE_TEXT
)
if transcript_columns: if transcript_columns:
self.transcript_index = transcript_columns[0] self.transcript_index = transcript_columns[0]
else: else:
if labeled is True: if labeled is True:
raise RuntimeError('No transcript data (missing in schema)') raise RuntimeError("No transcript data (missing in schema)")
sample_chunk_len = self.read_big_int() sample_chunk_len = self.read_big_int()
self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1) self.sdb_file.seek(sample_chunk_len + BIGINT_SIZE, 1)
@ -290,12 +312,16 @@ class SDB: # pylint: disable=too-many-instance-attributes
if mime_type is not None: if mime_type is not None:
criteria.append((MIME_TYPE_KEY, mime_type)) criteria.append((MIME_TYPE_KEY, mime_type))
if len(criteria) == 0: if len(criteria) == 0:
raise ValueError('At least one of "content" or "mime-type" has to be provided') raise ValueError(
'At least one of "content" or "mime-type" has to be provided'
)
matches = [] matches = []
for index, column in enumerate(self.schema): for index, column in enumerate(self.schema):
matched = 0 matched = 0
for field, value in criteria: for field, value in criteria:
if column[field] == value or (isinstance(value, list) and column[field] in value): if column[field] == value or (
isinstance(value, list) and column[field] in value
):
matched += 1 matched += 1
if matched == len(criteria): if matched == len(criteria):
matches.append(index) matches.append(index)
@ -306,8 +332,11 @@ class SDB: # pylint: disable=too-many-instance-attributes
column_data = [None] * len(columns) column_data = [None] * len(columns)
found = 0 found = 0
if not 0 <= row_index < len(self.offsets): if not 0 <= row_index < len(self.offsets):
raise ValueError('Wrong sample index: {} - has to be between 0 and {}' raise ValueError(
.format(row_index, len(self.offsets) - 1)) "Wrong sample index: {} - has to be between 0 and {}".format(
row_index, len(self.offsets) - 1
)
)
self.sdb_file.seek(self.offsets[row_index] + INT_SIZE) self.sdb_file.seek(self.offsets[row_index] + INT_SIZE)
for index in range(len(self.schema)): for index in range(len(self.schema)):
chunk_len = self.read_int() chunk_len = self.read_int()
@ -321,13 +350,17 @@ class SDB: # pylint: disable=too-many-instance-attributes
return tuple(column_data) return tuple(column_data)
def __getitem__(self, i):
    """Read row ``i`` and return it as a ``Sample`` or ``LabeledSample``.

    A ``LabeledSample`` carrying the decoded transcript is returned when a
    transcript column is configured; otherwise a plain ``Sample``.  The
    sample ID is ``"<id_prefix>:<i>"``.
    """
    sample_id = "{}:{}".format(self.id_prefix, i)
    if self.transcript_index is not None:
        audio_data, raw_transcript = self.read_row(
            i, self.speech_index, self.transcript_index
        )
        return LabeledSample(
            self.audio_type, audio_data, raw_transcript.decode(), sample_id=sample_id
        )
    (audio_data,) = self.read_row(i, self.speech_index)
    return Sample(self.audio_type, audio_data, sample_id=sample_id)
def __iter__(self): def __iter__(self):
for i in range(len(self.offsets)): for i in range(len(self.offsets)):
@ -346,10 +379,8 @@ class SDB: # pylint: disable=too-many-instance-attributes
class CSVWriter: # pylint: disable=too-many-instance-attributes class CSVWriter: # pylint: disable=too-many-instance-attributes
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples""" """Sample collection writer for writing a CSV data-set and all its referenced WAV samples"""
def __init__(self,
csv_filename, def __init__(self, csv_filename, absolute_paths=False, labeled=True):
absolute_paths=False,
labeled=True):
""" """
Parameters Parameters
---------- ----------
@ -372,11 +403,11 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
raise RuntimeError('"{}" already existing'.format(self.csv_dir)) raise RuntimeError('"{}" already existing'.format(self.csv_dir))
os.mkdir(str(self.csv_dir)) os.mkdir(str(self.csv_dir))
self.absolute_paths = absolute_paths self.absolute_paths = absolute_paths
fieldnames = ['wav_filename', 'wav_filesize'] fieldnames = ["wav_filename", "wav_filesize"]
self.labeled = labeled self.labeled = labeled
if labeled: if labeled:
fieldnames.append('transcript') fieldnames.append("transcript")
self.csv_file = open_remote(csv_filename, 'w', encoding='utf-8', newline='') self.csv_file = open_remote(csv_filename, "w", encoding="utf-8", newline="")
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
self.csv_writer.writeheader() self.csv_writer.writeheader()
self.counter = 0 self.counter = 0
@ -385,17 +416,19 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
return self return self
def add(self, sample): def add(self, sample):
sample_filename = self.csv_dir / 'sample{0:08d}.wav'.format(self.counter) sample_filename = self.csv_dir / "sample{0:08d}.wav".format(self.counter)
self.counter += 1 self.counter += 1
sample.change_audio_type(AUDIO_TYPE_PCM) sample.change_audio_type(AUDIO_TYPE_PCM)
write_wav(str(sample_filename), sample.audio, audio_format=sample.audio_format) write_wav(str(sample_filename), sample.audio, audio_format=sample.audio_format)
sample.sample_id = str(sample_filename.relative_to(self.csv_base_dir)) sample.sample_id = str(sample_filename.relative_to(self.csv_base_dir))
row = { row = {
'wav_filename': str(sample_filename.absolute()) if self.absolute_paths else sample.sample_id, "wav_filename": str(sample_filename.absolute())
'wav_filesize': sample_filename.stat().st_size if self.absolute_paths
else sample.sample_id,
"wav_filesize": sample_filename.stat().st_size,
} }
if self.labeled: if self.labeled:
row['transcript'] = sample.transcript row["transcript"] = sample.transcript
self.csv_writer.writerow(row) self.csv_writer.writerow(row)
return sample.sample_id return sample.sample_id
@ -412,11 +445,8 @@ class CSVWriter: # pylint: disable=too-many-instance-attributes
class TarWriter: # pylint: disable=too-many-instance-attributes class TarWriter: # pylint: disable=too-many-instance-attributes
"""Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file.""" """Sample collection writer for writing a CSV data-set and all its referenced WAV samples to a tar file."""
def __init__(self,
tar_filename, def __init__(self, tar_filename, gz=False, labeled=True, include=None):
gz=False,
labeled=True,
include=None):
""" """
Parameters Parameters
---------- ----------
@ -432,17 +462,19 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
Currently only works with local files (not gs:// or hdfs://...) Currently only works with local files (not gs:// or hdfs://...)
""" """
self.tar = tarfile.open(tar_filename, 'w:gz' if gz else 'w') self.tar = tarfile.open(tar_filename, "w:gz" if gz else "w")
samples_dir = tarfile.TarInfo('samples') samples_dir = tarfile.TarInfo("samples")
samples_dir.type = tarfile.DIRTYPE samples_dir.type = tarfile.DIRTYPE
self.tar.addfile(samples_dir) self.tar.addfile(samples_dir)
if include: if include:
for include_path in include: for include_path in include:
self.tar.add(include_path, recursive=False, arcname=Path(include_path).name) self.tar.add(
fieldnames = ['wav_filename', 'wav_filesize'] include_path, recursive=False, arcname=Path(include_path).name
)
fieldnames = ["wav_filename", "wav_filesize"]
self.labeled = labeled self.labeled = labeled
if labeled: if labeled:
fieldnames.append('transcript') fieldnames.append("transcript")
self.csv_file = io.StringIO() self.csv_file = io.StringIO()
self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames) self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
self.csv_writer.writeheader() self.csv_writer.writeheader()
@ -452,7 +484,7 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
return self return self
def add(self, sample): def add(self, sample):
sample_filename = 'samples/sample{0:08d}.wav'.format(self.counter) sample_filename = "samples/sample{0:08d}.wav".format(self.counter)
self.counter += 1 self.counter += 1
sample.change_audio_type(AUDIO_TYPE_PCM) sample.change_audio_type(AUDIO_TYPE_PCM)
sample_file = io.BytesIO() sample_file = io.BytesIO()
@ -462,21 +494,18 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
sample_tar = tarfile.TarInfo(sample_filename) sample_tar = tarfile.TarInfo(sample_filename)
sample_tar.size = sample_size sample_tar.size = sample_size
self.tar.addfile(sample_tar, sample_file) self.tar.addfile(sample_tar, sample_file)
row = { row = {"wav_filename": sample_filename, "wav_filesize": sample_size}
'wav_filename': sample_filename,
'wav_filesize': sample_size
}
if self.labeled: if self.labeled:
row['transcript'] = sample.transcript row["transcript"] = sample.transcript
self.csv_writer.writerow(row) self.csv_writer.writerow(row)
return sample_filename return sample_filename
def close(self): def close(self):
if self.csv_file and self.tar: if self.csv_file and self.tar:
csv_tar = tarfile.TarInfo('samples.csv') csv_tar = tarfile.TarInfo("samples.csv")
csv_tar.size = self.csv_file.tell() csv_tar.size = self.csv_file.tell()
self.csv_file.seek(0) self.csv_file.seek(0)
self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode('utf8'))) self.tar.addfile(csv_tar, io.BytesIO(self.csv_file.read().encode("utf8")))
if self.tar: if self.tar:
self.tar.close() self.tar.close()
@ -489,6 +518,7 @@ class TarWriter: # pylint: disable=too-many-instance-attributes
class SampleList: class SampleList:
"""Sample collection base class with samples loaded from a list of in-memory paths.""" """Sample collection base class with samples loaded from a list of in-memory paths."""
def __init__(self, samples, labeled=True, reverse=False): def __init__(self, samples, labeled=True, reverse=False):
""" """
Parameters Parameters
@ -507,7 +537,9 @@ class SampleList:
def __getitem__(self, i): def __getitem__(self, i):
sample_spec = self.samples[i] sample_spec = self.samples[i]
return load_sample(sample_spec[0], label=sample_spec[2] if self.labeled else None) return load_sample(
sample_spec[0], label=sample_spec[2] if self.labeled else None
)
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
@ -516,6 +548,7 @@ class SampleList:
class CSV(SampleList): class CSV(SampleList):
"""Sample collection reader for reading a Coqui STT CSV file """Sample collection reader for reading a Coqui STT CSV file
Automatically orders samples by CSV column wav_filesize (if available).""" Automatically orders samples by CSV column wav_filesize (if available)."""
def __init__(self, csv_filename, labeled=None, reverse=False): def __init__(self, csv_filename, labeled=None, reverse=False):
""" """
Parameters Parameters
@ -531,30 +564,34 @@ class CSV(SampleList):
If the order of the samples should be reversed If the order of the samples should be reversed
""" """
rows = [] rows = []
with open_remote(csv_filename, 'r', encoding='utf8') as csv_file: with open_remote(csv_filename, "r", encoding="utf8") as csv_file:
reader = csv.DictReader(csv_file) reader = csv.DictReader(csv_file)
if 'transcript' in reader.fieldnames: if "transcript" in reader.fieldnames:
if labeled is None: if labeled is None:
labeled = True labeled = True
elif labeled: elif labeled:
raise RuntimeError('No transcript data (missing CSV column)') raise RuntimeError("No transcript data (missing CSV column)")
for row in reader: for row in reader:
wav_filename = Path(row['wav_filename']) wav_filename = Path(row["wav_filename"])
if not wav_filename.is_absolute() and not is_remote_path(row['wav_filename']): if not wav_filename.is_absolute() and not is_remote_path(
row["wav_filename"]
):
wav_filename = Path(csv_filename).parent / wav_filename wav_filename = Path(csv_filename).parent / wav_filename
wav_filename = str(wav_filename) wav_filename = str(wav_filename)
else: else:
# Pathlib otherwise removes a / from filenames like hdfs:// # Pathlib otherwise removes a / from filenames like hdfs://
wav_filename = row['wav_filename'] wav_filename = row["wav_filename"]
wav_filesize = int(row['wav_filesize']) if 'wav_filesize' in row else 0 wav_filesize = int(row["wav_filesize"]) if "wav_filesize" in row else 0
if labeled: if labeled:
rows.append((wav_filename, wav_filesize, row['transcript'])) rows.append((wav_filename, wav_filesize, row["transcript"]))
else: else:
rows.append((wav_filename, wav_filesize)) rows.append((wav_filename, wav_filesize))
super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse) super(CSV, self).__init__(rows, labeled=labeled, reverse=reverse)
def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False): def samples_from_source(
sample_source, buffering=BUFFER_SIZE, labeled=None, reverse=False
):
""" """
Loads samples from a sample source file. Loads samples from a sample source file.
@ -577,14 +614,16 @@ def samples_from_source(sample_source, buffering=BUFFER_SIZE, labeled=None, reve
iterable of util.sample_collections.LabeledSample or util.audio.Sample instances supporting len. iterable of util.sample_collections.LabeledSample or util.audio.Sample instances supporting len.
""" """
ext = os.path.splitext(sample_source)[1].lower() ext = os.path.splitext(sample_source)[1].lower()
if ext == '.sdb': if ext == ".sdb":
return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse) return SDB(sample_source, buffering=buffering, labeled=labeled, reverse=reverse)
if ext == '.csv': if ext == ".csv":
return CSV(sample_source, labeled=labeled, reverse=reverse) return CSV(sample_source, labeled=labeled, reverse=reverse)
raise ValueError('Unknown file type: "{}"'.format(ext)) raise ValueError('Unknown file type: "{}"'.format(ext))
def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False): def samples_from_sources(
sample_sources, buffering=BUFFER_SIZE, labeled=None, reverse=False
):
""" """
Loads and combines samples from a list of source files. Sources are combined in an interleaving way to Loads and combines samples from a list of source files. Sources are combined in an interleaving way to
keep default sample order from shortest to longest. keep default sample order from shortest to longest.
@ -616,14 +655,22 @@ def samples_from_sources(sample_sources, buffering=BUFFER_SIZE, labeled=None, re
""" """
sample_sources = list(sample_sources) sample_sources = list(sample_sources)
if len(sample_sources) == 0: if len(sample_sources) == 0:
raise ValueError('No files') raise ValueError("No files")
if len(sample_sources) == 1: if len(sample_sources) == 1:
return samples_from_source(sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse) return samples_from_source(
sample_sources[0], buffering=buffering, labeled=labeled, reverse=reverse
)
# If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should # If we wish to interleave based on duration, we have to unpack the audio. Note that this unpacking should
# be done lazily onn the fly so that it respects the LimitingPool logic used in the feeding code. # be done lazily onn the fly so that it respects the LimitingPool logic used in the feeding code.
cols = [LenMap( cols = [
unpack_maybe, samples_from_source(source, buffering=buffering, labeled=labeled, reverse=reverse)) LenMap(
for source in sample_sources] unpack_maybe,
samples_from_source(
source, buffering=buffering, labeled=labeled, reverse=reverse
),
)
for source in sample_sources
]
return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse) return Interleaved(*cols, key=lambda s: s.duration, reverse=reverse)

View File

@ -1,27 +1,31 @@
import codecs import codecs
import unicodedata import unicodedata
class STMSegment(object): class STMSegment(object):
r""" r"""
Representation of an individual segment in an STM file. Representation of an individual segment in an STM file.
""" """
def __init__(self, stm_line): def __init__(self, stm_line):
tokens = stm_line.split() tokens = stm_line.split()
self._filename = tokens[0] self._filename = tokens[0]
self._channel = tokens[1] self._channel = tokens[1]
self._speaker_id = tokens[2] self._speaker_id = tokens[2]
self._start_time = float(tokens[3]) self._start_time = float(tokens[3])
self._stop_time = float(tokens[4]) self._stop_time = float(tokens[4])
self._labels = tokens[5] self._labels = tokens[5]
self._transcript = "" self._transcript = ""
for token in tokens[6:]: for token in tokens[6:]:
self._transcript += token + " " self._transcript += token + " "
# We need to do the encode-decode dance here because encode # We need to do the encode-decode dance here because encode
# returns a bytes() object on Python 3, and text_to_char_array # returns a bytes() object on Python 3, and text_to_char_array
# expects a string. # expects a string.
self._transcript = unicodedata.normalize("NFKD", self._transcript.strip()) \ self._transcript = (
.encode("ascii", "ignore") \ unicodedata.normalize("NFKD", self._transcript.strip())
.decode("ascii", "ignore") .encode("ascii", "ignore")
.decode("ascii", "ignore")
)
@property @property
def filename(self): def filename(self):
@ -51,6 +55,7 @@ class STMSegment(object):
def transcript(self): def transcript(self):
return self._transcript return self._transcript
def parse_stm_file(stm_file): def parse_stm_file(stm_file):
r""" r"""
Parses an STM file at ``stm_file`` into a list of :class:`STMSegment`. Parses an STM file at ``stm_file`` into a list of :class:`STMSegment`.

View File

@ -1,9 +1,11 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import numpy as np
import struct import struct
def text_to_char_array(transcript, alphabet, context=''): import numpy as np
def text_to_char_array(transcript, alphabet, context=""):
r""" r"""
Given a transcript string, map characters to Given a transcript string, map characters to
integers and return a numpy array representing the processed string. integers and return a numpy array representing the processed string.
@ -13,15 +15,20 @@ def text_to_char_array(transcript, alphabet, context=''):
# Provide the row context (especially wav_filename) for alphabet errors # Provide the row context (especially wav_filename) for alphabet errors
raise ValueError( raise ValueError(
'Alphabet cannot encode transcript "{}" while processing sample "{}", ' 'Alphabet cannot encode transcript "{}" while processing sample "{}", '
'check that your alphabet contains all characters in the training corpus. ' "check that your alphabet contains all characters in the training corpus. "
'Missing characters are: {}.' "Missing characters are: {}.".format(
.format(transcript, context, list(ch for ch in transcript if not alphabet.CanEncodeSingle(ch)))) transcript,
context,
list(ch for ch in transcript if not alphabet.CanEncodeSingle(ch)),
)
)
encoded = alphabet.Encode(transcript) encoded = alphabet.Encode(transcript)
if len(encoded) == 0: if len(encoded) == 0:
raise ValueError('While processing {}: Found an empty transcript! ' raise ValueError(
'You must include a transcript for all training data.' "While processing {}: Found an empty transcript! "
.format(context)) "You must include a transcript for all training data.".format(context)
)
return encoded return encoded
@ -35,6 +42,7 @@ def text_to_char_array(transcript, alphabet, context=''):
# version 1.0. This software is distributed without any warranty. For more # version 1.0. This software is distributed without any warranty. For more
# information, see <http://creativecommons.org/publicdomain/zero/1.0> # information, see <http://creativecommons.org/publicdomain/zero/1.0>
def levenshtein(a, b): def levenshtein(a, b):
"Calculates the Levenshtein distance between a and b." "Calculates the Levenshtein distance between a and b."
n, m = len(a), len(b) n, m = len(a), len(b)
@ -43,13 +51,13 @@ def levenshtein(a, b):
a, b = b, a a, b = b, a
n, m = m, n n, m = m, n
current = list(range(n+1)) current = list(range(n + 1))
for i in range(1, m+1): for i in range(1, m + 1):
previous, current = current, [i]+[0]*n previous, current = current, [i] + [0] * n
for j in range(1, n+1): for j in range(1, n + 1):
add, delete = previous[j]+1, current[j-1]+1 add, delete = previous[j] + 1, current[j - 1] + 1
change = previous[j-1] change = previous[j - 1]
if a[j-1] != b[i-1]: if a[j - 1] != b[i - 1]:
change = change + 1 change = change + 1
current[j] = min(add, delete, change) current[j] = min(add, delete, change)

View File

@ -2,24 +2,32 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import json
import os import os
import sys import sys
import json
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v1.logging as tflogging import tensorflow.compat.v1.logging as tflogging
tflogging.set_verbosity(tflogging.ERROR) tflogging.set_verbosity(tflogging.ERROR)
import logging import logging
logging.getLogger('sox').setLevel(logging.ERROR)
import glob
logging.getLogger("sox").setLevel(logging.ERROR)
import glob
from multiprocessing import Process, cpu_count
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.util.audio import AudioFile from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.config import Config, initialize_globals from coqui_stt_training.util.config import Config, initialize_globals
from coqui_stt_training.util.feeding import split_audio_file from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.flags import create_flags, FLAGS from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import log_error, log_info, log_progress, create_progressbar from coqui_stt_training.util.logging import (
from coqui_stt_ctcdecoder import ctc_beam_search_decoder_batch, Scorer create_progressbar,
from multiprocessing import Process, cpu_count log_error,
log_info,
log_progress,
)
def fail(message, code=1): def fail(message, code=1):
@ -28,8 +36,11 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path): def transcribe_file(audio_path, tlog_path):
from coqui_stt_training.train import create_model # pylint: disable=cyclic-import,import-outside-toplevel from coqui_stt_training.train import ( # pylint: disable=cyclic-import,import-outside-toplevel
create_model,
)
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
initialize_globals() initialize_globals()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try: try:
@ -37,16 +48,23 @@ def transcribe_file(audio_path, tlog_path):
except NotImplementedError: except NotImplementedError:
num_processes = 1 num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path: with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(wav_path, data_set = split_audio_file(
batch_size=FLAGS.batch_size, wav_path,
aggressiveness=FLAGS.vad_aggressiveness, batch_size=FLAGS.batch_size,
outlier_duration_ms=FLAGS.outlier_duration_ms, aggressiveness=FLAGS.vad_aggressiveness,
outlier_batch_size=FLAGS.outlier_batch_size) outlier_duration_ms=FLAGS.outlier_duration_ms,
iterator = tf.data.Iterator.from_structure(data_set.output_types, data_set.output_shapes, outlier_batch_size=FLAGS.outlier_batch_size,
output_classes=data_set.output_classes) )
iterator = tf.data.Iterator.from_structure(
data_set.output_types,
data_set.output_shapes,
output_classes=data_set.output_classes,
)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next() batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6 no_dropout = [None] * 6
logits, _ = create_model(batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout) logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step() tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session: with tf.Session(config=Config.session_config) as session:
@ -55,30 +73,43 @@ def transcribe_file(audio_path, tlog_path):
transcripts = [] transcripts = []
while True: while True:
try: try:
starts, ends, batch_logits, batch_lengths = \ starts, ends, batch_logits, batch_lengths = session.run(
session.run([batch_time_start, batch_time_end, transposed, batch_x_len]) [batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError: except tf.errors.OutOfRangeError:
break break
decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width, decoded = ctc_beam_search_decoder_batch(
num_processes=num_processes, batch_logits,
scorer=scorer) batch_lengths,
Config.alphabet,
FLAGS.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded) decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded)) transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0]) transcripts.sort(key=lambda t: t[0])
transcripts = [{'start': int(start), transcripts = [
'end': int(end), {"start": int(start), "end": int(end), "transcript": transcript}
'transcript': transcript} for start, end, transcript in transcripts] for start, end, transcript in transcripts
with open(tlog_path, 'w') as tlog_file: ]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float) json.dump(transcripts, tlog_file, default=float)
def transcribe_many(src_paths,dst_paths): def transcribe_many(src_paths, dst_paths):
pbar = create_progressbar(prefix='Transcribing files | ', max_value=len(src_paths)).start() pbar = create_progressbar(
prefix="Transcribing files | ", max_value=len(src_paths)
).start()
for i in range(len(src_paths)): for i in range(len(src_paths)):
p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i])) p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
p.start() p.start()
p.join() p.join()
log_progress('Transcribed file {} of {} from "{}" to "{}"'.format(i + 1, len(src_paths), src_paths[i], dst_paths[i])) log_progress(
'Transcribed file {} of {} from "{}" to "{}"'.format(
i + 1, len(src_paths), src_paths[i], dst_paths[i]
)
)
pbar.update(i) pbar.update(i)
pbar.finish() pbar.finish()
@ -99,70 +130,116 @@ def resolve(base_path, spec_path):
def main(_): def main(_):
if not FLAGS.src or not os.path.exists(FLAGS.src): if not FLAGS.src or not os.path.exists(FLAGS.src):
# path not given or non-existant # path not given or non-existant
fail('You have to specify which file or catalog to transcribe via the --src flag.') fail(
"You have to specify which file or catalog to transcribe via the --src flag."
)
else: else:
# path given and exists # path given and exists
src_path = os.path.abspath(FLAGS.src) src_path = os.path.abspath(FLAGS.src)
if os.path.isfile(src_path): if os.path.isfile(src_path):
if src_path.endswith('.catalog'): if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign) # Transcribe batch of files via ".catalog" file (from DSAlign)
if FLAGS.dst: if FLAGS.dst:
fail('Parameter --dst not supported if --src points to a catalog') fail("Parameter --dst not supported if --src points to a catalog")
catalog_dir = os.path.dirname(src_path) catalog_dir = os.path.dirname(src_path)
with open(src_path, 'r') as catalog_file: with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file) catalog_entries = json.load(catalog_file)
catalog_entries = [(resolve(catalog_dir, e['audio']), resolve(catalog_dir, e['tlog'])) for e in catalog_entries] catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)): if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail('Missing source file(s) in catalog') fail("Missing source file(s) in catalog")
if not FLAGS.force and any(map(lambda e: os.path.isfile(e[1]), catalog_entries)): if not FLAGS.force and any(
fail('Destination file(s) from catalog already existing, use --force for overwriting') map(lambda e: os.path.isfile(e[1]), catalog_entries)
if any(map(lambda e: not os.path.isdir(os.path.dirname(e[1])), catalog_entries)): ):
fail('Missing destination directory for at least one catalog entry') fail(
src_paths,dst_paths = zip(*paths) "Destination file(s) from catalog already existing, use --force for overwriting"
transcribe_many(src_paths,dst_paths) )
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else: else:
# Transcribe one file # Transcribe one file
dst_path = os.path.abspath(FLAGS.dst) if FLAGS.dst else os.path.splitext(src_path)[0] + '.tlog' dst_path = (
os.path.abspath(FLAGS.dst)
if FLAGS.dst
else os.path.splitext(src_path)[0] + ".tlog"
)
if os.path.isfile(dst_path): if os.path.isfile(dst_path):
if FLAGS.force: if FLAGS.force:
transcribe_one(src_path, dst_path) transcribe_one(src_path, dst_path)
else: else:
fail('Destination file "{}" already existing - use --force for overwriting'.format(dst_path), code=0) fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)): elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path) transcribe_one(src_path, dst_path)
else: else:
fail('Missing destination directory') fail("Missing destination directory")
elif os.path.isdir(src_path): elif os.path.isdir(src_path):
# Transcribe all files in dir # Transcribe all files in dir
print("Transcribing all WAV files in --src") print("Transcribing all WAV files in --src")
if FLAGS.dst: if FLAGS.dst:
fail('Destination file not supported for batch decoding jobs.') fail("Destination file not supported for batch decoding jobs.")
else: else:
if not FLAGS.recursive: if not FLAGS.recursive:
print("If you wish to recursively scan --src, then you must use --recursive") print(
"If you wish to recursively scan --src, then you must use --recursive"
)
wav_paths = glob.glob(src_path + "/*.wav") wav_paths = glob.glob(src_path + "/*.wav")
else: else:
wav_paths = glob.glob(src_path + "/**/*.wav") wav_paths = glob.glob(src_path + "/**/*.wav")
dst_paths = [path.replace('.wav','.tlog') for path in wav_paths] dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths,dst_paths) transcribe_many(wav_paths, dst_paths)
if __name__ == '__main__': if __name__ == "__main__":
create_flags() create_flags()
tf.app.flags.DEFINE_string('src', '', 'Source path to an audio file or directory or catalog file.' tf.app.flags.DEFINE_string(
'Catalog files should be formatted from DSAlign. A directory will' "src",
'be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be ' "",
'written in-place using the source filenames with ' "Source path to an audio file or directory or catalog file."
'suffix ".tlog" instead of ".wav".') "Catalog files should be formatted from DSAlign. A directory will"
tf.app.flags.DEFINE_string('dst', '', 'path for writing the transcription log or logs (.tlog). ' "be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be "
'If --src is a directory, this one also has to be a directory ' "written in-place using the source filenames with "
'and the required sub-dir tree of --src will get replicated.') 'suffix ".tlog" instead of ".wav".',
tf.app.flags.DEFINE_boolean('recursive', False, 'scan dir of audio recursively') )
tf.app.flags.DEFINE_boolean('force', False, 'Forces re-transcribing and overwriting of already existing ' tf.app.flags.DEFINE_string(
'transcription logs (.tlog)') "dst",
tf.app.flags.DEFINE_integer('vad_aggressiveness', 3, 'How aggressive (0=lowest, 3=highest) the VAD should ' "",
'split audio') "path for writing the transcription log or logs (.tlog). "
tf.app.flags.DEFINE_integer('batch_size', 40, 'Default batch size') "If --src is a directory, this one also has to be a directory "
tf.app.flags.DEFINE_float('outlier_duration_ms', 10000, 'Duration in ms after which samples are considered outliers') "and the required sub-dir tree of --src will get replicated.",
tf.app.flags.DEFINE_integer('outlier_batch_size', 1, 'Batch size for duration outliers (defaults to 1)') )
tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
tf.app.flags.DEFINE_boolean(
"force",
False,
"Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)",
)
tf.app.flags.DEFINE_integer(
"vad_aggressiveness",
3,
"How aggressive (0=lowest, 3=highest) the VAD should " "split audio",
)
tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
tf.app.flags.DEFINE_float(
"outlier_duration_ms",
10000,
"Duration in ms after which samples are considered outliers",
)
tf.app.flags.DEFINE_integer(
"outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
)
tf.app.run(main) tf.app.run(main)