Merge pull request #2025 from coqui-ai/various-fixes

Docs fixes, SavedModel export, transcribe.py revival
This commit is contained in:
Reuben Morais 2021-11-19 16:10:20 +01:00 committed by GitHub
commit 3020949075
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 469 additions and 264 deletions

View File

@ -808,7 +808,7 @@ jobs:
- run: | - run: |
mkdir -p ${CI_ARTIFACTS_DIR} || true mkdir -p ${CI_ARTIFACTS_DIR} || true
- run: | - run: |
sudo apt-get install -y --no-install-recommends libopus0 sudo apt-get install -y --no-install-recommends libopus0 sox
- name: Run extra training tests - name: Run extra training tests
run: | run: |
python -m pip install coqui_stt_ctcdecoder-*.whl python -m pip install coqui_stt_ctcdecoder-*.whl

View File

@ -8,14 +8,14 @@ from coqui_stt_training.evaluate import test
# only one GPU for only one training sample # only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
download_ldc("data/ldc93s1") download_ldc("data/smoke_test")
initialize_globals_from_args( initialize_globals_from_args(
load_train="init", load_train="init",
alphabet_config_path="data/alphabet.txt", alphabet_config_path="data/alphabet.txt",
train_files=["data/ldc93s1/ldc93s1.csv"], train_files=["data/smoke_test/ldc93s1.csv"],
dev_files=["data/ldc93s1/ldc93s1.csv"], dev_files=["data/smoke_test/ldc93s1.csv"],
test_files=["data/ldc93s1/ldc93s1.csv"], test_files=["data/smoke_test/ldc93s1.csv"],
augment=["time_mask"], augment=["time_mask"],
n_hidden=100, n_hidden=100,
epochs=200, epochs=200,

View File

@ -5,9 +5,9 @@ if [ ! -f train.py ]; then
exit 1 exit 1
fi; fi;
if [ ! -f "data/ldc93s1/ldc93s1.csv" ]; then if [ ! -f "data/smoke_test/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/ldc93s1." echo "Downloading and preprocessing LDC93S1 example data, saving in ./data/smoke_test."
python -u bin/import_ldc93s1.py ./data/ldc93s1 python -u bin/import_ldc93s1.py ./data/smoke_test
fi; fi;
if [ -d "${COMPUTE_KEEP_DIR}" ]; then if [ -d "${COMPUTE_KEEP_DIR}" ]; then
@ -23,8 +23,8 @@ export CUDA_VISIBLE_DEVICES=0
python -m coqui_stt_training.train \ python -m coqui_stt_training.train \
--alphabet_config_path "data/alphabet.txt" \ --alphabet_config_path "data/alphabet.txt" \
--show_progressbar false \ --show_progressbar false \
--train_files data/ldc93s1/ldc93s1.csv \ --train_files data/smoke_test/ldc93s1.csv \
--test_files data/ldc93s1/ldc93s1.csv \ --test_files data/smoke_test/ldc93s1.csv \
--train_batch_size 1 \ --train_batch_size 1 \
--test_batch_size 1 \ --test_batch_size 1 \
--n_hidden 100 \ --n_hidden 100 \

View File

@ -16,7 +16,7 @@ mkdir -p /tmp/train_tflite || true
set -o pipefail set -o pipefail
python -m pip install --upgrade pip setuptools wheel | cat python -m pip install --upgrade pip setuptools wheel | cat
python -m pip install --upgrade . | cat python -m pip install --upgrade ".[transcribe]" | cat
set +o pipefail set +o pipefail
# Prepare correct arguments for training # Prepare correct arguments for training
@ -72,3 +72,20 @@ time python ./bin/run-ldc93s1.py
# Training graph inference # Training graph inference
time ./bin/run-ci-ldc93s1_singleshotinference.sh time ./bin/run-ci-ldc93s1_singleshotinference.sh
# transcribe module
time python -m coqui_stt_training.transcribe \
--src "data/smoke_test/LDC93S1.wav" \
--dst ${CI_ARTIFACTS_DIR}/transcribe.log \
--n_hidden 100 \
--scorer_path "data/smoke_test/pruned_lm.scorer"
#TODO: investigate why this is hanging in CI
#mkdir /tmp/transcribe_dir
#cp data/smoke_test/LDC93S1.wav /tmp/transcribe_dir
#time python -m coqui_stt_training.transcribe \
# --src "/tmp/transcribe_dir/" \
# --n_hidden 100 \
# --scorer_path "data/smoke_test/pruned_lm.scorer"
#
#for i in data/smoke_test/*.tlog; do echo $i; cat $i; echo; done

View File

@ -1,4 +1,4 @@
.. _c-usage: .. _c-api:
C API C API
===== =====

View File

@ -16,7 +16,7 @@ You can deploy 🐸STT models either via a command-line client or a language bin
* :ref:`The Node.JS package + language binding <nodejs-usage>` * :ref:`The Node.JS package + language binding <nodejs-usage>`
* :ref:`The Android libstt AAR package <android-usage>` * :ref:`The Android libstt AAR package <android-usage>`
* :ref:`The command-line client <cli-usage>` * :ref:`The command-line client <cli-usage>`
* :ref:`The native C API <c-usage>` * :ref:`The C API <c-usage>`
.. _download-models: .. _download-models:
@ -172,7 +172,7 @@ This will link all .aar files in the ``libs`` directory you just created, includ
Using the command-line client Using the command-line client
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The pre-built binaries for the ``stt`` command-line (compiled C++) client are available in the ``native_client.tar.xz`` archive for your desired platform. You can download the archive from our `releases page <https://github.com/coqui-ai/STT/releases>`_. The pre-built binaries for the ``stt`` command-line (compiled C++) client are available in the ``native_client.*.tar.xz`` archive for your desired platform (where the * is the appropriate identifier for the platform you want to run on). You can download the archive from our `releases page <https://github.com/coqui-ai/STT/releases>`_.
Assuming you have :ref:`downloaded the pre-trained models <download-models>`, you can use the client as such: Assuming you have :ref:`downloaded the pre-trained models <download-models>`, you can use the client as such:
@ -182,6 +182,15 @@ Assuming you have :ref:`downloaded the pre-trained models <download-models>`, yo
See the help output with ``./stt -h`` for more details. See the help output with ``./stt -h`` for more details.
.. _c-usage:
Using the C API
^^^^^^^^^^^^^^^
Alongside the pre-built binaries for the ``stt`` command-line client described :ref:`above <cli-usage>`, in the same ``native_client.*.tar.xz`` platform-specific archive, you'll find the ``coqui-stt.h`` header file as well as the pre-built shared libraries needed to use the 🐸STT C API. You can download the archive from our `releases page <https://github.com/coqui-ai/STT/releases>`_.
Then, simply include the header file and link against the shared libraries in your project, and you should be able to use the C API. Reference documentation is available in :ref:`c-api`.
Installing bindings from source Installing bindings from source
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -78,8 +78,8 @@
"def download_sample_data():\n", "def download_sample_data():\n",
" data_dir=\"english/\"\n", " data_dir=\"english/\"\n",
" # Download data + alphabet\n", " # Download data + alphabet\n",
" audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n", " audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.wav\")\n",
" transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n", " transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/LDC93S1.txt\")\n",
" alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n", " alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
" # Format data\n", " # Format data\n",
" with open(transcript_file, \"r\") as fin:\n", " with open(transcript_file, \"r\") as fin:\n",

View File

@ -69,6 +69,9 @@ def main():
python_requires=">=3.5, <4", python_requires=">=3.5, <4",
install_requires=install_requires, install_requires=install_requires,
include_package_data=True, include_package_data=True,
extras_require={
"transcribe": ["webrtcvad"],
},
) )

View File

@ -9,13 +9,15 @@ DESIRED_LOG_LEVEL = (
) )
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
import shutil import shutil
from .deepspeech_model import create_inference_graph from .deepspeech_model import create_inference_graph, create_model
from .util.checkpoints import load_graph_for_evaluation from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error, log_info from .util.config import Config, initialize_globals_from_cli, log_error, log_info
from .util.feeding import wavfile_bytes_to_features
from .util.io import ( from .util.io import (
open_remote, open_remote,
rmtree_remote, rmtree_remote,
@ -35,6 +37,9 @@ def export():
""" """
log_info("Exporting the model...") log_info("Exporting the model...")
if Config.export_savedmodel:
return export_savedmodel()
tfv1.reset_default_graph() tfv1.reset_default_graph()
inputs, outputs, _ = create_inference_graph( inputs, outputs, _ = create_inference_graph(
@ -172,6 +177,72 @@ def export():
) )
def export_savedmodel():
tfv1.reset_default_graph()
with tfv1.Session(config=Config.session_config) as session:
input_wavfile_contents = tf.placeholder(tf.string)
features, features_len = wavfile_bytes_to_features(input_wavfile_contents)
previous_state_c = tf.zeros([1, Config.n_cell_dim], tf.float32)
previous_state_h = tf.zeros([1, Config.n_cell_dim], tf.float32)
previous_state = tf.nn.rnn_cell.LSTMStateTuple(
previous_state_c, previous_state_h
)
# Add batch dimension
features = tf.expand_dims(features, 0)
features_len = tf.expand_dims(features_len, 0)
# One rate per layer
no_dropout = [None] * 6
logits, layers = create_model(
batch_x=features,
batch_size=1,
seq_length=features_len,
dropout=no_dropout,
previous_state=previous_state,
)
# Restore variables from training checkpoint
load_graph_for_evaluation(session)
probs = tf.nn.softmax(logits)
# Remove batch dimension
squeezed = tf.squeeze(probs)
builder = tfv1.saved_model.builder.SavedModelBuilder(Config.export_dir)
input_file_tinfo = tfv1.saved_model.utils.build_tensor_info(
input_wavfile_contents
)
output_probs_tinfo = tfv1.saved_model.utils.build_tensor_info(squeezed)
forward_sig = tfv1.saved_model.signature_def_utils.build_signature_def(
inputs={
"input_wavfile": input_file_tinfo,
},
outputs={
"probs": output_probs_tinfo,
},
method_name="forward",
)
builder.add_meta_graph_and_variables(
session,
[tfv1.saved_model.tag_constants.SERVING],
signature_def_map={
tfv1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: forward_sig
},
)
builder.save()
log_info(f"Exported SavedModel to {Config.export_dir}")
def package_zip(): def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
export_dir = os.path.join( export_dir = os.path.join(

View File

@ -0,0 +1,315 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This script is structured in a weird way, with delayed imports. This is due
# to the use of multiprocessing. TensorFlow cannot handle forking, and even with
# the spawn strategy set to "spawn" it still leads to weird problems, so we
# restructure the code so that TensorFlow is only imported inside the child
# processes.
import os
import sys
import glob
import itertools
import json
import multiprocessing
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, field
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from tqdm import tqdm
def fail(message, code=1):
print(f"E {message}")
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
log_level_index = (
sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
)
desired_log_level = (
sys.argv[log_level_index] if 0 < log_level_index < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = desired_log_level
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from coqui_stt_training.train import create_model
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
from coqui_stt_training.util.config import Config
from coqui_stt_training.util.feeding import split_audio_file
initialize_transcribe_config()
scorer = None
if Config.scorer_path:
scorer = Scorer(
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
)
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=Config.batch_size,
aggressiveness=Config.vad_aggressiveness,
outlier_duration_ms=Config.outlier_duration_ms,
outlier_batch_size=Config.outlier_batch_size,
)
iterator = tfv1.data.make_one_shot_iterator(data_set)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
transcripts = []
while True:
try:
starts, ends, batch_logits, batch_lengths = session.run(
[batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
Config.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0])
transcripts = [
{"start": int(start), "end": int(end), "transcript": transcript}
for start, end, transcript in transcripts
]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float)
def step_function(job):
""" Wrap transcribe_file to unpack arguments from a single tuple """
idx, src, dst = job
transcribe_file(src, dst)
return idx, src, dst
def transcribe_many(src_paths, dst_paths):
from coqui_stt_training.util.config import Config, log_progress
pool = Pool(processes=min(cpu_count(), len(src_paths)))
# Create list of items to be processed: [(i, src_path[i], dst_paths[i])]
jobs = zip(itertools.count(), src_paths, dst_paths)
process_iterable = tqdm(
pool.imap_unordered(step_function, jobs),
desc="Transcribing files",
total=len(src_paths),
disable=not Config.show_progressbar,
)
for result in process_iterable:
idx, src, dst = result
log_progress(
f'Transcribed file {idx+1} of {len(src_paths)} from "{src}" to "{dst}"'
)
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
print(f'I Transcribed file "{src_path}" to "{dst_path}"')
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def transcribe():
from coqui_stt_training.util.config import Config
initialize_transcribe_config()
if not Config.src or not os.path.exists(Config.src):
# path not given or non-existant
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(Config.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not Config.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file
dst_path = (
os.path.abspath(Config.dst)
if Config.dst
else os.path.splitext(src_path)[0] + ".tlog"
)
if os.path.isfile(dst_path):
if Config.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if Config.recursive:
wav_paths = glob.glob(os.path.join(src_path, "**", "*.wav"))
else:
wav_paths = glob.glob(os.path.join(src_path, "*.wav"))
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
def initialize_transcribe_config():
from coqui_stt_training.util.config import (
BaseSttConfig,
initialize_globals_from_instance,
)
@dataclass
class TranscribeConfig(BaseSttConfig):
src: str = field(
default="",
metadata=dict(
help="Source path to an audio file or directory or catalog file. "
"Catalog files should be formatted from DSAlign. A directory "
"will be recursively searched for audio. If --dst not set, "
"transcription logs (.tlog) will be written in-place using the "
'source filenames with suffix ".tlog" instead of ".wav".'
),
)
dst: str = field(
default="",
metadata=dict(
help="path for writing the transcription log or logs (.tlog). "
"If --src is a directory, this one also has to be a directory "
"and the required sub-dir tree of --src will get replicated."
),
)
recursive: bool = field(
default=False,
metadata=dict(help="scan source directory recursively for audio"),
)
force: bool = field(
default=False,
metadata=dict(
help="Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)"
),
)
vad_aggressiveness: int = field(
default=3,
metadata=dict(help="VAD aggressiveness setting (0=lowest, 3=highest)"),
)
batch_size: int = field(
default=40,
metadata=dict(help="Default batch size"),
)
outlier_duration_ms: int = field(
default=10000,
metadata=dict(
help="Duration in ms after which samples are considered outliers"
),
)
outlier_batch_size: int = field(
default=1,
metadata=dict(help="Batch size for duration outliers (defaults to 1)"),
)
def __post_init__(self):
if os.path.isfile(self.src) and self.src.endswith(".catalog") and self.dst:
raise RuntimeError(
"Parameter --dst not supported if --src points to a catalog"
)
if os.path.isdir(self.src):
if self.dst:
raise RuntimeError(
"Destination path not supported for batch decoding jobs."
)
super().__post_init__()
config = TranscribeConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config)
def main():
from coqui_stt_training.util.helpers import check_ctcdecoder_version
try:
import webrtcvad
except ImportError:
print(
"E transcribe module requires webrtcvad, which cannot be imported. Install with pip install webrtcvad"
)
sys.exit(1)
check_ctcdecoder_version()
transcribe()
if __name__ == "__main__":
main()

View File

@ -75,9 +75,12 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
init_vars.add(v) init_vars.add(v)
load_vars -= init_vars load_vars -= init_vars
log_info(f"Vars to load: {list(sorted(v.op.name for v in load_vars))}")
for v in sorted(load_vars, key=lambda v: v.op.name): for v in sorted(load_vars, key=lambda v: v.op.name):
log_info("Loading variable from checkpoint: %s" % (v.op.name)) log_info(f"Getting tensor from variable: {v.op.name}")
v.load(ckpt.get_tensor(v.op.name), session=session) tensor = ckpt.get_tensor(v.op.name)
log_info(f"Loading tensor from checkpoint: {v.op.name}")
v.load(tensor, session=session)
for v in sorted(init_vars, key=lambda v: v.op.name): for v in sorted(init_vars, key=lambda v: v.op.name):
log_info("Initializing variable: %s" % (v.op.name)) log_info("Initializing variable: %s" % (v.op.name))

View File

@ -37,7 +37,7 @@ Config = _ConfigSingleton() # pylint: disable=invalid-name
@dataclass @dataclass
class _SttConfig(Coqpit): class BaseSttConfig(Coqpit):
def __post_init__(self): def __post_init__(self):
# Augmentations # Augmentations
self.augmentations = parse_augmentations(self.augment) self.augmentations = parse_augmentations(self.augment)
@ -587,6 +587,10 @@ class _SttConfig(Coqpit):
default=True, default=True,
metadata=dict(help="export a quantized model (optimized for size)"), metadata=dict(help="export a quantized model (optimized for size)"),
) )
export_savedmodel: bool = field(
default=False,
metadata=dict(help="export model in TF SavedModel format"),
)
n_steps: int = field( n_steps: int = field(
default=16, default=16,
metadata=dict( metadata=dict(
@ -831,16 +835,22 @@ class _SttConfig(Coqpit):
def initialize_globals_from_cli(): def initialize_globals_from_cli():
c = _SttConfig.init_from_argparse(arg_prefix="") c = BaseSttConfig.init_from_argparse(arg_prefix="")
_ConfigSingleton._config = c # pylint: disable=protected-access _ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_args(**override_args): def initialize_globals_from_args(**override_args):
# Update Config with new args # Update Config with new args
c = _SttConfig(**override_args) c = BaseSttConfig(**override_args)
_ConfigSingleton._config = c # pylint: disable=protected-access _ConfigSingleton._config = c # pylint: disable=protected-access
def initialize_globals_from_instance(config):
""" Initialize Config singleton from an existing Config instance (or subclass) """
assert isinstance(config, BaseSttConfig)
_ConfigSingleton._config = config # pylint: disable=protected-access
# Logging functions # Logging functions
# ================= # =================

View File

@ -84,6 +84,14 @@ def audiofile_to_features(
wav_filename, clock=0.0, train_phase=False, augmentations=None wav_filename, clock=0.0, train_phase=False, augmentations=None
): ):
samples = tf.io.read_file(wav_filename) samples = tf.io.read_file(wav_filename)
return wavfile_bytes_to_features(
samples, clock, train_phase, augmentations, sample_id=wav_filename
)
def wavfile_bytes_to_features(
samples, clock=0.0, train_phase=False, augmentations=None, sample_id=None
):
decoded = contrib_audio.decode_wav(samples, desired_channels=1) decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return audio_to_features( return audio_to_features(
decoded.audio, decoded.audio,
@ -91,7 +99,7 @@ def audiofile_to_features(
clock=clock, clock=clock,
train_phase=train_phase, train_phase=train_phase,
augmentations=augmentations, augmentations=augmentations,
sample_id=wav_filename, sample_id=sample_id,
) )

View File

@ -2,246 +2,15 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import json
import os
import sys
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow.compat.v1.logging as tflogging
import tensorflow as tf
tflogging.set_verbosity(tflogging.ERROR)
import logging
logging.getLogger("sox").setLevel(logging.ERROR)
import glob
from multiprocessing import Process, cpu_count
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.util.audio import AudioFile
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import (
create_progressbar,
log_error,
log_info,
log_progress,
)
def fail(message, code=1):
log_error(message)
sys.exit(code)
def transcribe_file(audio_path, tlog_path):
from coqui_stt_training.train import ( # pylint: disable=cyclic-import,import-outside-toplevel
create_model,
)
from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
initialize_globals_from_cli()
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
try:
num_processes = cpu_count()
except NotImplementedError:
num_processes = 1
with AudioFile(audio_path, as_path=True) as wav_path:
data_set = split_audio_file(
wav_path,
batch_size=FLAGS.batch_size,
aggressiveness=FLAGS.vad_aggressiveness,
outlier_duration_ms=FLAGS.outlier_duration_ms,
outlier_batch_size=FLAGS.outlier_batch_size,
)
iterator = tf.data.Iterator.from_structure(
data_set.output_types,
data_set.output_shapes,
output_classes=data_set.output_classes,
)
batch_time_start, batch_time_end, batch_x, batch_x_len = iterator.get_next()
no_dropout = [None] * 6
logits, _ = create_model(
batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout
)
transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
tf.train.get_or_create_global_step()
with tf.Session(config=Config.session_config) as session:
load_graph_for_evaluation(session)
session.run(iterator.make_initializer(data_set))
transcripts = []
while True:
try:
starts, ends, batch_logits, batch_lengths = session.run(
[batch_time_start, batch_time_end, transposed, batch_x_len]
)
except tf.errors.OutOfRangeError:
break
decoded = ctc_beam_search_decoder_batch(
batch_logits,
batch_lengths,
Config.alphabet,
FLAGS.beam_width,
num_processes=num_processes,
scorer=scorer,
)
decoded = list(d[0][1] for d in decoded)
transcripts.extend(zip(starts, ends, decoded))
transcripts.sort(key=lambda t: t[0])
transcripts = [
{"start": int(start), "end": int(end), "transcript": transcript}
for start, end, transcript in transcripts
]
with open(tlog_path, "w") as tlog_file:
json.dump(transcripts, tlog_file, default=float)
def transcribe_many(src_paths, dst_paths):
pbar = create_progressbar(
prefix="Transcribing files | ", max_value=len(src_paths)
).start()
for i in range(len(src_paths)):
p = Process(target=transcribe_file, args=(src_paths[i], dst_paths[i]))
p.start()
p.join()
log_progress(
'Transcribed file {} of {} from "{}" to "{}"'.format(
i + 1, len(src_paths), src_paths[i], dst_paths[i]
)
)
pbar.update(i)
pbar.finish()
def transcribe_one(src_path, dst_path):
transcribe_file(src_path, dst_path)
log_info('Transcribed file "{}" to "{}"'.format(src_path, dst_path))
def resolve(base_path, spec_path):
if spec_path is None:
return None
if not os.path.isabs(spec_path):
spec_path = os.path.join(base_path, spec_path)
return spec_path
def main(_):
if not FLAGS.src or not os.path.exists(FLAGS.src):
# path not given or non-existant
fail(
"You have to specify which file or catalog to transcribe via the --src flag."
)
else:
# path given and exists
src_path = os.path.abspath(FLAGS.src)
if os.path.isfile(src_path):
if src_path.endswith(".catalog"):
# Transcribe batch of files via ".catalog" file (from DSAlign)
if FLAGS.dst:
fail("Parameter --dst not supported if --src points to a catalog")
catalog_dir = os.path.dirname(src_path)
with open(src_path, "r") as catalog_file:
catalog_entries = json.load(catalog_file)
catalog_entries = [
(resolve(catalog_dir, e["audio"]), resolve(catalog_dir, e["tlog"]))
for e in catalog_entries
]
if any(map(lambda e: not os.path.isfile(e[0]), catalog_entries)):
fail("Missing source file(s) in catalog")
if not FLAGS.force and any(
map(lambda e: os.path.isfile(e[1]), catalog_entries)
):
fail(
"Destination file(s) from catalog already existing, use --force for overwriting"
)
if any(
map(
lambda e: not os.path.isdir(os.path.dirname(e[1])),
catalog_entries,
)
):
fail("Missing destination directory for at least one catalog entry")
src_paths, dst_paths = zip(*paths)
transcribe_many(src_paths, dst_paths)
else:
# Transcribe one file
dst_path = (
os.path.abspath(FLAGS.dst)
if FLAGS.dst
else os.path.splitext(src_path)[0] + ".tlog"
)
if os.path.isfile(dst_path):
if FLAGS.force:
transcribe_one(src_path, dst_path)
else:
fail(
'Destination file "{}" already existing - use --force for overwriting'.format(
dst_path
),
code=0,
)
elif os.path.isdir(os.path.dirname(dst_path)):
transcribe_one(src_path, dst_path)
else:
fail("Missing destination directory")
elif os.path.isdir(src_path):
# Transcribe all files in dir
print("Transcribing all WAV files in --src")
if FLAGS.dst:
fail("Destination file not supported for batch decoding jobs.")
else:
if not FLAGS.recursive:
print(
"If you wish to recursively scan --src, then you must use --recursive"
)
wav_paths = glob.glob(src_path + "/*.wav")
else:
wav_paths = glob.glob(src_path + "/**/*.wav")
dst_paths = [path.replace(".wav", ".tlog") for path in wav_paths]
transcribe_many(wav_paths, dst_paths)
if __name__ == "__main__": if __name__ == "__main__":
create_flags() print(
tf.app.flags.DEFINE_string( "Using the top level transcribe.py script is deprecated and will be removed "
"src", "in a future release. Instead use: python -m coqui_stt_training.transcribe"
"",
"Source path to an audio file or directory or catalog file."
"Catalog files should be formatted from DSAlign. A directory will"
"be recursively searched for audio. If --dst not set, transcription logs (.tlog) will be "
"written in-place using the source filenames with "
'suffix ".tlog" instead of ".wav".',
) )
tf.app.flags.DEFINE_string( try:
"dst", from coqui_stt_training import transcribe as stt_transcribe
"", except ImportError:
"path for writing the transcription log or logs (.tlog). " print("Training package is not installed. See training documentation.")
"If --src is a directory, this one also has to be a directory " raise
"and the required sub-dir tree of --src will get replicated.",
) stt_transcribe.main()
tf.app.flags.DEFINE_boolean("recursive", False, "scan dir of audio recursively")
tf.app.flags.DEFINE_boolean(
"force",
False,
"Forces re-transcribing and overwriting of already existing "
"transcription logs (.tlog)",
)
tf.app.flags.DEFINE_integer(
"vad_aggressiveness",
3,
"How aggressive (0=lowest, 3=highest) the VAD should " "split audio",
)
tf.app.flags.DEFINE_integer("batch_size", 40, "Default batch size")
tf.app.flags.DEFINE_float(
"outlier_duration_ms",
10000,
"Duration in ms after which samples are considered outliers",
)
tf.app.flags.DEFINE_integer(
"outlier_batch_size", 1, "Batch size for duration outliers (defaults to 1)"
)
tf.app.run(main)