diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 8da63bd3..bd2e918a 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -808,7 +808,7 @@ jobs: - run: | mkdir -p ${CI_ARTIFACTS_DIR} || true - run: | - sudo apt-get install -y --no-install-recommends libopus0 + sudo apt-get install -y --no-install-recommends libopus0 sox - name: Run extra training tests run: | python -m pip install coqui_stt_ctcdecoder-*.whl diff --git a/bin/run-ldc93s1.py b/bin/run-ldc93s1.py index b25cc998..0fead323 100755 --- a/bin/run-ldc93s1.py +++ b/bin/run-ldc93s1.py @@ -8,14 +8,14 @@ from coqui_stt_training.evaluate import test # only one GPU for only one training sample os.environ["CUDA_VISIBLE_DEVICES"] = "0" -download_ldc("data/ldc93s1") +download_ldc("data/smoke_test") initialize_globals_from_args( load_train="init", alphabet_config_path="data/alphabet.txt", - train_files=["data/ldc93s1/ldc93s1.csv"], - dev_files=["data/ldc93s1/ldc93s1.csv"], - test_files=["data/ldc93s1/ldc93s1.csv"], + train_files=["data/smoke_test/ldc93s1.csv"], + dev_files=["data/smoke_test/ldc93s1.csv"], + test_files=["data/smoke_test/ldc93s1.csv"], augment=["time_mask"], n_hidden=100, epochs=200, diff --git a/bin/run-ldc93s1.sh b/bin/run-ldc93s1.sh index 2bd80c59..34ffa0dd 100755 --- a/bin/run-ldc93s1.sh +++ b/bin/run-ldc93s1.sh @@ -5,9 +5,9 @@ if [ ! -f train.py ]; then exit 1 fi; -if [ ! -f "data/ldc93s1/ldc93s1.csv" ]; then - echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/ldc93s1." - python -u bin/import_ldc93s1.py ./data/ldc93s1 +if [ ! -f "data/smoke_test/ldc93s1.csv" ]; then + echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/smoke_test." + python -u bin/import_ldc93s1.py ./data/smoke_test fi; if [ -d "${COMPUTE_KEEP_DIR}" ]; then @@ -23,8 +23,8 @@ export CUDA_VISIBLE_DEVICES=0 python -m coqui_stt_training.train \ --alphabet_config_path "data/alphabet.txt" \ --show_progressbar false \ - --train_files data/ldc93s1/ldc93s1.csv \ - --test_files data/ldc93s1/ldc93s1.csv \ + --train_files data/smoke_test/ldc93s1.csv \ + --test_files data/smoke_test/ldc93s1.csv \ --train_batch_size 1 \ --test_batch_size 1 \ --n_hidden 100 \ diff --git a/ci_scripts/train-extra-tests.sh b/ci_scripts/train-extra-tests.sh index f538110d..82987474 100755 --- a/ci_scripts/train-extra-tests.sh +++ b/ci_scripts/train-extra-tests.sh @@ -16,7 +16,7 @@ mkdir -p /tmp/train_tflite || true set -o pipefail python -m pip install --upgrade pip setuptools wheel | cat -python -m pip install --upgrade . | cat +python -m pip install --upgrade ".[transcribe]" | cat set +o pipefail # Prepare correct arguments for training @@ -72,3 +72,20 @@ time python ./bin/run-ldc93s1.py # Training graph inference time ./bin/run-ci-ldc93s1_singleshotinference.sh + +# transcribe module +time python -m coqui_stt_training.transcribe \ + --src "data/smoke_test/LDC93S1.wav" \ + --dst ${CI_ARTIFACTS_DIR}/transcribe.log \ + --n_hidden 100 \ + --scorer_path "data/smoke_test/pruned_lm.scorer" + +#TODO: investigate why this is hanging in CI +#mkdir /tmp/transcribe_dir +#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir +#time python -m coqui_stt_training.transcribe \ +# --src "/tmp/transcribe_dir/" \ +# --n_hidden 100 \ +# --scorer_path "data/smoke_test/pruned_lm.scorer" +# +#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done diff --git a/doc/C-API.rst b/doc/C-API.rst index b76c06b8..1fd8bd4e 100644 --- a/doc/C-API.rst +++ b/doc/C-API.rst @@ -1,4 +1,4 @@ -.. _c-usage: +.. _c-api: C API ===== diff --git a/doc/DEPLOYMENT.rst b/doc/DEPLOYMENT.rst index 964e4606..0d60b2be 100644 --- a/doc/DEPLOYMENT.rst +++ b/doc/DEPLOYMENT.rst @@ -16,7 +16,7 @@ You can deploy 🐸STT models either via a command-line client or a language bin * :ref:`The Node.JS package + language binding ` * :ref:`The Android libstt AAR package ` * :ref:`The command-line client ` -* :ref:`The native C API ` +* :ref:`The C API ` .. _download-models: @@ -172,7 +172,7 @@ This will link all .aar files in the ``libs`` directory you just created, includ Using the command-line client ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The pre-built binaries for the ``stt`` command-line (compiled C++) client are available in the ``native_client.tar.xz`` archive for your desired platform. You can download the archive from our `releases page `_. +The pre-built binaries for the ``stt`` command-line (compiled C++) client are available in the ``native_client.*.tar.xz`` archive for your desired platform (where the * is the appropriate identifier for the platform you want to run on). You can download the archive from our `releases page `_. Assuming you have :ref:`downloaded the pre-trained models `, you can use the client as such: @@ -182,6 +182,15 @@ Assuming you have :ref:`downloaded the pre-trained models `, yo See the help output with ``./stt -h`` for more details. +.. _c-usage: + +Using the C API +^^^^^^^^^^^^^^^ + +Alongside the pre-built binaries for the ``stt`` command-line client described :ref:`above `, in the same ``native_client.*.tar.xz`` platform-specific archive, you'll find the ``coqui-stt.h`` header file as well as the pre-built shared libraries needed to use the 🐸STT C API. You can download the archive from our `releases page `_. + +Then, simply include the header file and link against the shared libraries in your project, and you should be able to use the C API. Reference documentation is available in :ref:`c-api`. + Installing bindings from source ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/notebooks/train_your_first_coqui_STT_model.ipynb b/notebooks/train_your_first_coqui_STT_model.ipynb index df885b2d..404cc58d 100644 --- a/notebooks/train_your_first_coqui_STT_model.ipynb +++ b/notebooks/train_your_first_coqui_STT_model.ipynb @@ -78,8 +78,8 @@ "def download_sample_data():\n", " data_dir=\"english/\"\n", " # Download data + alphabet\n", - " audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n", - " transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n", + " audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.wav\")\n", + " transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.txt\")\n", " alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n", " # Format data\n", " with open(transcript_file, \"r\") as fin:\n", diff --git a/setup.py b/setup.py index 9d850c58..8a3469df 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,9 @@ def main(): python_requires=">=3.5, <4", install_requires=install_requires, include_package_data=True, + extras_require={ + "transcribe": ["webrtcvad"], + }, ) diff --git a/training/coqui_stt_training/export.py b/training/coqui_stt_training/export.py index b1fedbbc..782835bf 100644 --- a/training/coqui_stt_training/export.py +++ b/training/coqui_stt_training/export.py @@ -9,13 +9,15 @@ DESIRED_LOG_LEVEL = ( ) os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL +import numpy as np import tensorflow as tf import tensorflow.compat.v1 as tfv1 import shutil -from .deepspeech_model import create_inference_graph +from .deepspeech_model import create_inference_graph, create_model from .util.checkpoints import load_graph_for_evaluation from .util.config import Config, initialize_globals_from_cli, log_error, log_info +from .util.feeding import wavfile_bytes_to_features from .util.io import ( open_remote, rmtree_remote, @@ -35,6 +37,9 @@ def export(): """ log_info("Exporting the model...") + if Config.export_savedmodel: + return export_savedmodel() + tfv1.reset_default_graph() inputs, outputs, _ = create_inference_graph( @@ -172,6 +177,72 @@ def export(): ) +def export_savedmodel(): + tfv1.reset_default_graph() + + with tfv1.Session(config=Config.session_config) as session: + input_wavfile_contents = tf.placeholder(tf.string) + + features, features_len = wavfile_bytes_to_features(input_wavfile_contents) + previous_state_c = tf.zeros([1, Config.n_cell_dim], tf.float32) + previous_state_h = tf.zeros([1, Config.n_cell_dim], tf.float32) + + previous_state = tf.nn.rnn_cell.LSTMStateTuple( + previous_state_c, previous_state_h + ) + + # Add batch dimension + features = tf.expand_dims(features, 0) + features_len = tf.expand_dims(features_len, 0) + + # One rate per layer + no_dropout = [None] * 6 + + logits, layers = create_model( + batch_x=features, + batch_size=1, + seq_length=features_len, + dropout=no_dropout, + previous_state=previous_state, + ) + + # Restore variables from training checkpoint + load_graph_for_evaluation(session) + + probs = tf.nn.softmax(logits) + + # Remove batch dimension + squeezed = tf.squeeze(probs) + + builder = tfv1.saved_model.builder.SavedModelBuilder(Config.export_dir) + + input_file_tinfo = tfv1.saved_model.utils.build_tensor_info( + input_wavfile_contents + ) + output_probs_tinfo = tfv1.saved_model.utils.build_tensor_info(squeezed) + + forward_sig = tfv1.saved_model.signature_def_utils.build_signature_def( + inputs={ + "input_wavfile": input_file_tinfo, + }, + outputs={ + "probs": output_probs_tinfo, + }, + method_name="forward", + ) + + builder.add_meta_graph_and_variables( + session, + [tfv1.saved_model.tag_constants.SERVING], + signature_def_map={ + tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: forward_sig + }, + ) + + builder.save() + log_info(f"Exported SavedModel to {Config.export_dir}") + + def package_zip(): # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip export_dir = os.path.join( diff --git a/training/coqui_stt_training/transcribe.py b/training/coqui_stt_training/transcribe.py new file mode 100755 index 00000000..a761d516 --- /dev/null +++ b/training/coqui_stt_training/transcribe.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# This script is structured in a weird way, with delayed imports. This is due +# to the use of multiprocessing. TensorFlow cannot handle forking, and even with +# the spawn strategy set to "spawn" it still leads to weird problems, so we +# restructure the code so that TensorFlow is only imported inside the child +# processes. + +import os +import sys +import glob +import itertools +import json +import multiprocessing +from multiprocessing import Pool, cpu_count +from dataclasses import dataclass, field + +from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch +from tqdm import tqdm + + +def fail(message, code=1): + print(f"E {message}") + sys.exit(code) + + +def transcribe_file(audio_path, tlog_path): + log_level_index = ( + sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0 + ) + desired_log_level = ( + sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3" + ) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level + + import tensorflow as tf + import tensorflow.compat.v1 as tfv1 + + from coqui_stt_training.train import create_model + from coqui_stt_training.util.audio import AudioFile + from coqui_stt_training.util.checkpoints import load_graph_for_evaluation + from coqui_stt_training.util.config import Config + from coqui_stt_training.util.feeding import split_audio_file + + initialize_transcribe_config() + + scorer = None + if Config.scorer_path: + scorer = Scorer( + Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet + ) + + try: + num_processes = cpu_count() + except NotImplementedError: + num_processes = 1 + + with AudioFile(audio_path, as_path=True) as wav_path: + data_set = split_audio_file( + wav_path, + batch_size=Config.batch_size, + aggressiveness=Config.vad_aggressiveness, + outlier_duration_ms=Config.outlier_duration_ms, + outlier_batch_size=Config.outlier_batch_size, + ) + iterator = tfv1.data.make_one_shot_iterator(data_set) + batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next() + no_dropout = [None] * 6 + logits, _ = create_model( + batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout + ) + transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) + tf.train.get_or_create_global_step() + with tf.Session(config=Config.session_config) as session: + load_graph_for_evaluation(session) + transcripts = [] + while True: + try: + starts, ends, batch_logits, batch_lengths = session.run( + [batch_time_start, batch_time_end, transposed, batch_x_len] + ) + except tf.errors.OutOfRangeError: + break + decoded = ctc_beam_search_decoder_batch( + batch_logits, + batch_lengths, + Config.alphabet, + Config.beam_width, + num_processes=num_processes, + scorer=scorer, + ) + decoded = list(d[0][1] for d in decoded) + transcripts.extend(zip(starts, ends, decoded)) + transcripts.sort(key=lambda t: t[0]) + transcripts = [ + {"start": int(start), "end": int(end), "transcript": transcript} + for start, end, transcript in transcripts + ] + with open(tlog_path, "w") as tlog_file: + json.dump(transcripts, tlog_file, default=float) + + +def step_function(job): + """ Wrap transcribe_file to unpack arguments from a single tuple """ + idx, src, dst = job + transcribe_file(src, dst) + return idx, src, dst + + +def transcribe_many(src_paths, dst_paths): + from coqui_stt_training.util.config import Config, log_progress + + pool = Pool(processes=min(cpu_count(), len(src_paths))) + + # Create list of items to be processed: [(i, src_path[i], dst_paths[i])] + jobs = zip(itertools.count(), src_paths, dst_paths) + + process_iterable = tqdm( + pool.imap_unordered(step_function, jobs), + desc="Transcribing files", + total=len(src_paths), + disable=not Config.show_progressbar, + ) + + for result in process_iterable: + idx, src, dst = result + log_progress( + f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"' + ) + + +def transcribe_one(src_path, dst_path): + transcribe_file(src_path, dst_path) + print(f'I Transcribed file "{src_path}" to "{dst_path}"') + + +def resolve(base_path, spec_path): + if spec_path is None: + return None + if not os.path.isabs(spec_path): + spec_path = os.path.join(base_path, spec_path) + return spec_path + + +def transcribe(): + from coqui_stt_training.util.config import Config + + initialize_transcribe_config() + + if not Config.src or not os.path.exists(Config.src): + # path not given or non-existant + fail( + "You have to specify which file or catalog to transcribe via the --src flag." + ) + else: + # path given and exists + src_path = os.path.abspath(Config.src) + if os.path.isfile(src_path): + if src_path.endswith(".catalog"): + # Transcribe batch of files via ".catalog" file (from DSAlign) + catalog_dir = os.path.dirname(src_path) + with open(src_path, "r") as catalog_file: + catalog_entries = json.load(catalog_file) + catalog_entries = [ + (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"])) + for e in catalog_entries + ] + if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)): + fail("Missing source file(s) in catalog") + if not Config.force and any( + map(lambda e: os.path.isfile(e[1]), catalog_entries) + ): + fail( + "Destination file(s) from catalog already existing, use --force for overwriting" + ) + if any( + map( + lambda e: not os.path.isdir(os.path.dirname(e[1])), + catalog_entries, + ) + ): + fail("Missing destination directory for at least one catalog entry") + src_paths, dst_paths = zip(*paths) + transcribe_many(src_paths, dst_paths) + else: + # Transcribe one file + dst_path = ( + os.path.abspath(Config.dst) + if Config.dst + else os.path.splitext(src_path)[0] + ".tlog" + ) + if os.path.isfile(dst_path): + if Config.force: + transcribe_one(src_path, dst_path) + else: + fail( + 'Destination file "{}" already existing - use --force for overwriting'.format( + dst_path + ), + code=0, + ) + elif os.path.isdir(os.path.dirname(dst_path)): + transcribe_one(src_path, dst_path) + else: + fail("Missing destination directory") + elif os.path.isdir(src_path): + # Transcribe all files in dir + print("Transcribing all WAV files in --src") + if Config.recursive: + wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav")) + else: + wav_paths = glob.glob(os.path.join(src_path, "*.wav")) + dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths] + transcribe_many(wav_paths, dst_paths) + + +def initialize_transcribe_config(): + from coqui_stt_training.util.config import ( + BaseSttConfig, + initialize_globals_from_instance, + ) + + @dataclass + class TranscribeConfig(BaseSttConfig): + src: str = field( + default="", + metadata=dict( + help="Source path to an audio file or directory or catalog file. " + "Catalog files should be formatted from DSAlign. A directory " + "will be recursively searched for audio. If --dst not set, " + "transcription logs (.tlog) will be written in-place using the " + 'source filenames with suffix ".tlog" instead of ".wav".' + ), + ) + + dst: str = field( + default="", + metadata=dict( + help="path for writing the transcription log or logs (.tlog). " + "If --src is a directory, this one also has to be a directory " + "and the required sub-dir tree of --src will get replicated." + ), + ) + + recursive: bool = field( + default=False, + metadata=dict(help="scan source directory recursively for audio"), + ) + + force: bool = field( + default=False, + metadata=dict( + help="Forces re-transcribing and overwriting of already existing " + "transcription logs (.tlog)" + ), + ) + + vad_aggressiveness: int = field( + default=3, + metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"), + ) + + batch_size: int = field( + default=40, + metadata=dict(help="Default batch size"), + ) + + outlier_duration_ms: int = field( + default=10000, + metadata=dict( + help="Duration in ms after which samples are considered outliers" + ), + ) + + outlier_batch_size: int = field( + default=1, + metadata=dict(help="Batch size for duration outliers (defaults to 1)"), + ) + + def __post_init__(self): + if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst: + raise RuntimeError( + "Parameter --dst not supported if --src points to a catalog" + ) + + if os.path.isdir(self.src): + if self.dst: + raise RuntimeError( + "Destination path not supported for batch decoding jobs." + ) + + super().__post_init__() + + config = TranscribeConfig.init_from_argparse(arg_prefix="") + initialize_globals_from_instance(config) + + +def main(): + from coqui_stt_training.util.helpers import check_ctcdecoder_version + + try: + import webrtcvad + except ImportError: + print( + "E transcribe module requires webrtcvad, which cannot be imported. Install with pip install webrtcvad" + ) + sys.exit(1) + + check_ctcdecoder_version() + transcribe() + + +if __name__ == "__main__": + main() diff --git a/training/coqui_stt_training/util/checkpoints.py b/training/coqui_stt_training/util/checkpoints.py index 434d403c..9d5452ae 100644 --- a/training/coqui_stt_training/util/checkpoints.py +++ b/training/coqui_stt_training/util/checkpoints.py @@ -75,9 +75,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init= init_vars.add(v) load_vars -= init_vars + log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}") for v in sorted(load_vars, key=lambda v: v.op.name): - log_info("Loading variable from checkpoint: %s" % (v.op.name)) - v.load(ckpt.get_tensor(v.op.name), session=session) + log_info(f"Getting tensor from variable: {v.op.name}") + tensor = ckpt.get_tensor(v.op.name) + log_info(f"Loading tensor from checkpoint: {v.op.name}") + v.load(tensor, session=session) for v in sorted(init_vars, key=lambda v: v.op.name): log_info("Initializing variable: %s" % (v.op.name)) diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index 008953a8..9ca10b80 100644 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -37,7 +37,7 @@ Config = _ConfigSingleton() # pylint: disable=invalid-name @dataclass -class _SttConfig(Coqpit): +class BaseSttConfig(Coqpit): def __post_init__(self): # Augmentations self.augmentations = parse_augmentations(self.augment) @@ -587,6 +587,10 @@ class _SttConfig(Coqpit): default=True, metadata=dict(help="export a quantized model (optimized for size)"), ) + export_savedmodel: bool = field( + default=False, + metadata=dict(help="export model in TF SavedModel format"), + ) n_steps: int = field( default=16, metadata=dict( @@ -831,16 +835,22 @@ class _SttConfig(Coqpit): def initialize_globals_from_cli(): - c = _SttConfig.init_from_argparse(arg_prefix="") + c = BaseSttConfig.init_from_argparse(arg_prefix="") _ConfigSingleton._config = c # pylint: disable=protected-access def initialize_globals_from_args(**override_args): # Update Config with new args - c = _SttConfig(**override_args) + c = BaseSttConfig(**override_args) _ConfigSingleton._config = c # pylint: disable=protected-access +def initialize_globals_from_instance(config): + """ Initialize Config singleton from an existing Config instance (or subclass) """ + assert isinstance(config, BaseSttConfig) + _ConfigSingleton._config = config # pylint: disable=protected-access + + # Logging functions # ================= diff --git a/training/coqui_stt_training/util/feeding.py b/training/coqui_stt_training/util/feeding.py index bf506375..5a0d8109 100644 --- a/training/coqui_stt_training/util/feeding.py +++ b/training/coqui_stt_training/util/feeding.py @@ -84,6 +84,14 @@ def audiofile_to_features( wav_filename, clock=0.0, train_phase=False, augmentations=None ): samples = tf.io.read_file(wav_filename) + return wavfile_bytes_to_features( + samples, clock, train_phase, augmentations, sample_id=wav_filename + ) + + +def wavfile_bytes_to_features( + samples, clock=0.0, train_phase=False, augmentations=None, sample_id=None +): decoded = contrib_audio.decode_wav(samples, desired_channels=1) return audio_to_features( decoded.audio, @@ -91,7 +99,7 @@ def audiofile_to_features( clock=clock, train_phase=train_phase, augmentations=augmentations, - sample_id=wav_filename, + sample_id=sample_id, ) diff --git a/transcribe.py b/transcribe.py index 2792ae2f..4458ad8a 100755 --- a/transcribe.py +++ b/transcribe.py @@ -2,246 +2,15 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function -import json -import os -import sys - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow.compat.v1.logging as tflogging - -import tensorflow as tf - -tflogging.set_verbosity(tflogging.ERROR) -import logging - -logging.getLogger("sox").setLevel(logging.ERROR) -import glob -from multiprocessing import Process, cpu_count - -from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch -from coqui_stt_training.util.audio import AudioFile -from coqui_stt_training.util.config import Config, initialize_globals_from_cli -from coqui_stt_training.util.feeding import split_audio_file -from coqui_stt_training.util.flags import FLAGS, create_flags -from coqui_stt_training.util.logging import ( - create_progressbar, - log_error, - log_info, - log_progress, -) - - -def fail(message, code=1): - log_error(message) - sys.exit(code) - - -def transcribe_file(audio_path, tlog_path): - from coqui_stt_training.train import ( # pylint: disable=cyclic-import,import-outside-toplevel - create_model, - ) - from coqui_stt_training.util.checkpoints import load_graph_for_evaluation - - initialize_globals_from_cli() - - scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet) - try: - num_processes = cpu_count() - except NotImplementedError: - num_processes = 1 - with AudioFile(audio_path, as_path=True) as wav_path: - data_set = split_audio_file( - wav_path, - batch_size=FLAGS.batch_size, - aggressiveness=FLAGS.vad_aggressiveness, - outlier_duration_ms=FLAGS.outlier_duration_ms, - outlier_batch_size=FLAGS.outlier_batch_size, - ) - iterator = tf.data.Iterator.from_structure( - data_set.output_types, - data_set.output_shapes, - output_classes=data_set.output_classes, - ) - batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next() - no_dropout = [None] * 6 - logits, _ = create_model( - batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout - ) - transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2])) - tf.train.get_or_create_global_step() - with tf.Session(config=Config.session_config) as session: - load_graph_for_evaluation(session) - session.run(iterator.make_initializer(data_set)) - transcripts = [] - while True: - try: - starts, ends, batch_logits, batch_lengths = session.run( - [batch_time_start, batch_time_end, transposed, batch_x_len] - ) - except tf.errors.OutOfRangeError: - break - decoded = ctc_beam_search_decoder_batch( - batch_logits, - batch_lengths, - Config.alphabet, - FLAGS.beam_width, - num_processes=num_processes, - scorer=scorer, - ) - decoded = list(d[0][1] for d in decoded) - transcripts.extend(zip(starts, ends, decoded)) - transcripts.sort(key=lambda t: t[0]) - transcripts = [ - {"start": int(start), "end": int(end), "transcript": transcript} - for start, end, transcript in transcripts - ] - with open(tlog_path, "w") as tlog_file: - json.dump(transcripts, tlog_file, default=float) - - -def transcribe_many(src_paths, dst_paths): - pbar = create_progressbar( - prefix="Transcribing files | ", max_value=len(src_paths) - ).start() - for i in range(len(src_paths)): - p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i])) - p.start() - p.join() - log_progress( - 'Transcribed file {} of {} from "{}" to "{}"'.format( - i + 1, len(src_paths), src_paths[i], dst_paths[i] - ) - ) - pbar.update(i) - pbar.finish() - - -def transcribe_one(src_path, dst_path): - transcribe_file(src_path, dst_path) - log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path)) - - -def resolve(base_path, spec_path): - if spec_path is None: - return None - if not os.path.isabs(spec_path): - spec_path = os.path.join(base_path, spec_path) - return spec_path - - -def main(_): - if not FLAGS.src or not os.path.exists(FLAGS.src): - # path not given or non-existant - fail( - "You have to specify which file or catalog to transcribe via the --src flag." - ) - else: - # path given and exists - src_path = os.path.abspath(FLAGS.src) - if os.path.isfile(src_path): - if src_path.endswith(".catalog"): - # Transcribe batch of files via ".catalog" file (from DSAlign) - if FLAGS.dst: - fail("Parameter --dst not supported if --src points to a catalog") - catalog_dir = os.path.dirname(src_path) - with open(src_path, "r") as catalog_file: - catalog_entries = json.load(catalog_file) - catalog_entries = [ - (resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"])) - for e in catalog_entries - ] - if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)): - fail("Missing source file(s) in catalog") - if not FLAGS.force and any( - map(lambda e: os.path.isfile(e[1]), catalog_entries) - ): - fail( - "Destination file(s) from catalog already existing, use --force for overwriting" - ) - if any( - map( - lambda e: not os.path.isdir(os.path.dirname(e[1])), - catalog_entries, - ) - ): - fail("Missing destination directory for at least one catalog entry") - src_paths, dst_paths = zip(*paths) - transcribe_many(src_paths, dst_paths) - else: - # Transcribe one file - dst_path = ( - os.path.abspath(FLAGS.dst) - if FLAGS.dst - else os.path.splitext(src_path)[0] + ".tlog" - ) - if os.path.isfile(dst_path): - if FLAGS.force: - transcribe_one(src_path, dst_path) - else: - fail( - 'Destination file "{}" already existing - use --force for overwriting'.format( - dst_path - ), - code=0, - ) - elif os.path.isdir(os.path.dirname(dst_path)): - transcribe_one(src_path, dst_path) - else: - fail("Missing destination directory") - elif os.path.isdir(src_path): - # Transcribe all files in dir - print("Transcribing all WAV files in --src") - if FLAGS.dst: - fail("Destination file not supported for batch decoding jobs.") - else: - if not FLAGS.recursive: - print( - "If you wish to recursively scan --src, then you must use --recursive" - ) - wav_paths = glob.glob(src_path + "/*.wav") - else: - wav_paths = glob.glob(src_path + "/**/*.wav") - dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths] - transcribe_many(wav_paths, dst_paths) - - if __name__ == "__main__": - create_flags() - tf.app.flags.DEFINE_string( - "src", - "", - "Source path to an audio file or directory or catalog file." - "Catalog files should be formatted from DSAlign. A directory will" - "be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be " - "written in-place using the source filenames with " - 'suffix ".tlog" instead of ".wav".', + print( + "Using the top level transcribe.py script is deprecated and will be removed " + "in a future release. Instead use: python -m coqui_stt_training.transcribe" ) - tf.app.flags.DEFINE_string( - "dst", - "", - "path for writing the transcription log or logs (.tlog). " - "If --src is a directory, this one also has to be a directory " - "and the required sub-dir tree of --src will get replicated.", - ) - tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively") - tf.app.flags.DEFINE_boolean( - "force", - False, - "Forces re-transcribing and overwriting of already existing " - "transcription logs (.tlog)", - ) - tf.app.flags.DEFINE_integer( - "vad_aggressiveness", - 3, - "How aggressive (0=lowest, 3=highest) the VAD should " "split audio", - ) - tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size") - tf.app.flags.DEFINE_float( - "outlier_duration_ms", - 10000, - "Duration in ms after which samples are considered outliers", - ) - tf.app.flags.DEFINE_integer( - "outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)" - ) - tf.app.run(main) + try: + from coqui_stt_training import transcribe as stt_transcribe + except ImportError: + print("Training package is not installed. See training documentation.") + raise + + stt_transcribe.main()