Merge remote-tracking branch 'upstream/main' into intermediateDecodeExpensive
commit 36ea530d42
@@ -44,6 +44,7 @@ async function getGoodArtifacts(client, owner, repo, releaseId, name) {

async function main() {
  try {
    const token = core.getInput("github_token", { required: true });
    const [owner, repo] = core.getInput("repo", { required: true }).split("/");
    const path = core.getInput("path", { required: true });
    const name = core.getInput("name");
@@ -51,6 +52,7 @@ async function main() {
    const releaseTag = core.getInput("release-tag");
    const OctokitWithThrottling = GitHub.plugin(throttling);
    const client = new OctokitWithThrottling({
      auth: token,
      throttle: {
        onRateLimit: (retryAfter, options) => {
          console.log(
@@ -61,6 +63,9 @@ async function main() {
          if (options.request.retryCount <= 2) {
            console.log(`Retrying after ${retryAfter} seconds!`);
            return true;
          } else {
            console.log("Exhausted 2 retries");
            core.setFailed("Exhausted 2 retries");
          }
        },
        onAbuseLimit: (retryAfter, options) => {
@@ -68,6 +73,7 @@ async function main() {
          console.log(
            `Abuse detected for request ${options.method} ${options.url}`
          );
          core.setFailed(`GitHub REST API Abuse detected for request ${options.method} ${options.url}`)
        },
      },
    });
@@ -108,6 +114,7 @@ async function main() {
      await Download(artifact.url, dir, {
        headers: {
          "Accept": "application/octet-stream",
          "Authorization": `token ${token}`,
        },
      });
    }
@@ -37,6 +37,7 @@ async function getGoodArtifacts(client, owner, repo, releaseId, name) {

async function main() {
  try {
    const token = core.getInput("github_token", { required: true });
    const [owner, repo] = core.getInput("repo", { required: true }).split("/");
    const path = core.getInput("path", { required: true });
    const name = core.getInput("name");
@@ -44,6 +45,7 @@ async function main() {
    const releaseTag = core.getInput("release-tag");
    const OctokitWithThrottling = GitHub.plugin(throttling);
    const client = new OctokitWithThrottling({
      auth: token,
      throttle: {
        onRateLimit: (retryAfter, options) => {
          console.log(
@@ -54,6 +56,9 @@ async function main() {
          if (options.request.retryCount <= 2) {
            console.log(`Retrying after ${retryAfter} seconds!`);
            return true;
          } else {
            console.log("Exhausted 2 retries");
            core.setFailed("Exhausted 2 retries");
          }
        },
        onAbuseLimit: (retryAfter, options) => {
@@ -61,6 +66,7 @@ async function main() {
          console.log(
            `Abuse detected for request ${options.method} ${options.url}`
          );
          core.setFailed(`GitHub REST API Abuse detected for request ${options.method} ${options.url}`)
        },
      },
    });
@@ -101,6 +107,7 @@ async function main() {
      await Download(artifact.url, dir, {
        headers: {
          "Accept": "application/octet-stream",
          "Authorization": `token ${token}`,
        },
      });
    }
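Not part of the diff: a rough Python sketch of the bounded-retry policy the onRateLimit handlers above encode (retry while a small budget lasts, then fail). The RateLimited exception type is hypothetical, for illustration only.

import time

class RateLimited(Exception):
    """Hypothetical error carrying the server-suggested wait time."""
    def __init__(self, retry_after):
        super().__init__(f"rate limited, retry after {retry_after}s")
        self.retry_after = retry_after

def call_with_retries(request, max_retries=2):
    # Mirrors the handlers above: retry after the suggested delay while
    # the retry budget lasts, otherwise give up and surface a failure.
    for attempt in range(max_retries + 1):
        try:
            return request()
        except RateLimited as err:
            if attempt < max_retries:
                print(f"Retrying after {err.retry_after} seconds!")
                time.sleep(err.retry_after)
            else:
                raise RuntimeError(f"Exhausted {max_retries} retries")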
.github/actions/host-build/action.yml (vendored) | 5
@@ -5,11 +5,8 @@ inputs:
    description: "Target arch for loading script (host/armv7/aarch64)"
    required: false
    default: "host"
  flavor:
    description: "Build flavor"
    required: true
runs:
  using: "composite"
  steps:
    - run: ./ci_scripts/${{ inputs.arch }}-build.sh ${{ inputs.flavor }}
    - run: ./ci_scripts/${{ inputs.arch }}-build.sh
      shell: bash
.github/actions/numpy_vers/action.yml (vendored) | 10
@@ -28,15 +28,15 @@ runs:
        case "${{ inputs.pyver }}" in
          3.7*)
            NUMPY_BUILD_VERSION="==1.14.5"
            NUMPY_DEP_VERSION=">=1.14.5"
            NUMPY_DEP_VERSION=">=1.14.5,<=1.19.4"
            ;;
          3.8*)
            NUMPY_BUILD_VERSION="==1.17.3"
            NUMPY_DEP_VERSION=">=1.17.3"
            NUMPY_DEP_VERSION=">=1.17.3,<=1.19.4"
            ;;
          3.9*)
            NUMPY_BUILD_VERSION="==1.19.4"
            NUMPY_DEP_VERSION=">=1.19.4"
            NUMPY_DEP_VERSION=">=1.19.4,<=1.19.4"
            ;;
        esac
        ;;
@@ -57,7 +57,7 @@ runs:
        ;;
          3.9*)
            NUMPY_BUILD_VERSION="==1.19.4"
            NUMPY_DEP_VERSION=">=1.19.4"
            NUMPY_DEP_VERSION=">=1.19.4,<=1.19.4"
            ;;
        esac
        ;;
@@ -82,7 +82,7 @@ runs:
        ;;
          3.9*)
            NUMPY_BUILD_VERSION="==1.19.4"
            NUMPY_DEP_VERSION=">=1.19.4"
            NUMPY_DEP_VERSION=">=1.19.4,<=1.19.4"
            ;;
        esac
        ;;
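For context (not part of the diff): the tightened NUMPY_DEP_VERSION values behave like pip version specifiers, capping the runtime NumPy at 1.19.4. A quick Python check, assuming the third-party packaging library is installed:

from packaging.specifiers import SpecifierSet

old = SpecifierSet(">=1.17.3")            # previous, open-ended dependency
new = SpecifierSet(">=1.17.3,<=1.19.4")   # tightened pin from this diff

print("1.19.4" in old, "1.19.4" in new)   # True True
print("1.20.0" in old, "1.20.0" in new)   # True False: NumPy 1.20+ now excluded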
.github/actions/python-build/action.yml (vendored) | 24
@@ -1,9 +1,6 @@
name: "Python binding"
description: "Binding a python binding"
inputs:
  build_flavor:
    description: "Python package name"
    required: true
  numpy_build:
    description: "NumPy build dependecy"
    required: true
@@ -36,22 +33,15 @@ runs:
    - run: |
        python3 --version
        pip3 --version
        python3 -m pip install virtualenv
        python3 -m virtualenv stt-build
      shell: bash
    - run: |
        mkdir -p wheels
      shell: bash
    - run: |
        set -xe

        PROJECT_NAME="stt"
        if [ "${{ inputs.build_flavor }}" = "tflite" ]; then
          PROJECT_NAME="stt-tflite"
        fi

        OS=$(uname)
        if [ "${OS}" = "Linux" ]; then
        if [ "${OS}" = "Linux" -a "${{ inputs.target }}" != "host" ]; then
          python3 -m venv stt-build
          source stt-build/bin/activate
        fi

@@ -65,14 +55,4 @@ runs:
          RASPBIAN=${{ inputs.chroot }} \
          SETUP_FLAGS="--project_name ${PROJECT_NAME}" \
          bindings-clean bindings

        if [ "${OS}" = "Linux" ]; then
          deactivate
        fi
      shell: bash
    - run: |
        cp native_client/python/dist/*.whl wheels
      shell: bash
    - run: |
        make -C native_client/python/ bindings-clean
      shell: bash
.github/actions/run-tests/action.yml (vendored) | 8
@@ -4,9 +4,6 @@ inputs:
  runtime:
    description: "Runtime to use for running test"
    required: true
  build-flavor:
    description: "Running against TF or TFLite"
    required: true
  model-kind:
    description: "Running against CI baked or production model"
    required: true
@@ -22,10 +19,7 @@ runs:
    - run: |
        set -xe

        build=""
        if [ "${{ inputs.build-flavor }}" = "tflite" ]; then
          build="_tflite"
        fi
        build="_tflite"

        model_kind=""
        if [ "${{ inputs.model-kind }}" = "prod" ]; then
.github/pull_request_template.md (vendored) | 8
@@ -5,3 +5,11 @@ Welcome to the 🐸STT project! We are excited to see your interest, and appreci
This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.

In order to make a good pull request, please see our [CONTRIBUTING.rst](CONTRIBUTING.rst) file, in particular make sure you have set-up and run the pre-commit hook to check your changes for code style violations.

Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/STT).

This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/STT):

- Protects you, Coqui, and the users of the code.
- Does not change your rights to use your contributions for any purpose.
- Does not change the license of the 🐸STT project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.
.github/workflows/build-and-test.yml (vendored) | 703
(File diff suppressed because it is too large)
@@ -1,4 +1,4 @@
exclude: '^(taskcluster|.github|native_client/kenlm|native_client/ctcdecode/third_party|tensorflow|kenlm|doc/examples|data/alphabet.txt)'
exclude: '^(taskcluster|.github|native_client/kenlm|native_client/ctcdecode/third_party|tensorflow|kenlm|doc/examples|data/alphabet.txt|data/smoke_test)'
repos:
  - repo: 'https://github.com/pre-commit/pre-commit-hooks'
    rev: v2.3.0
Dockerfile.train | 100
@@ -1,39 +1,73 @@
# Please refer to the TRAINING documentation, "Basic Dockerfile for training"
# This is a Dockerfile useful for training models with Coqui STT.
# You can train "acoustic models" with audio + Tensorflow, and
# you can create "scorers" with text + KenLM.

FROM nvcr.io/nvidia/tensorflow:21.05-tf1-py3
FROM ubuntu:20.04 AS kenlm-build
ENV DEBIAN_FRONTEND=noninteractive

# We need to purge python3-xdg because it's breaking STT install later with
# weird errors about setuptools
#
# libopus0 and libsndfile1 are dependencies for audio augmentation
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential cmake libboost-system-dev \
        libboost-thread-dev libboost-program-options-dev \
        libboost-test-dev libeigen3-dev zlib1g-dev \
        libbz2-dev liblzma-dev && \
    rm -rf /var/lib/apt/lists/*

# Build KenLM to generate new scorers
WORKDIR /code
COPY kenlm /code/kenlm
RUN cd /code/kenlm && \
    mkdir -p build && \
    cd build && \
    cmake .. && \
    make -j $(nproc) || \
    ( echo "ERROR: Failed to build KenLM."; \
      echo "ERROR: Make sure you update the kenlm submodule on host before building this Dockerfile."; \
      echo "ERROR: $ cd STT; git submodule update --init kenlm"; \
      exit 1; )


FROM ubuntu:20.04 AS wget-binaries
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && \
    apt-get install -y --no-install-recommends wget unzip xz-utils && \
    rm -rf /var/lib/apt/lists/*

# Tool to convert output graph for inference
RUN wget --no-check-certificate https://github.com/coqui-ai/STT/releases/download/v0.9.3/convert_graphdef_memmapped_format.linux.amd64.zip -O temp.zip && \
    unzip temp.zip && \
    rm temp.zip

RUN wget --no-check-certificate https://github.com/reuben/STT/releases/download/v0.10.0-alpha.1/native_client.tar.xz -O temp.tar.xz && \
    tar -xf temp.tar.xz && \
    rm temp.tar.xz


FROM nvcr.io/nvidia/tensorflow:20.06-tf1-py3
ENV DEBIAN_FRONTEND=noninteractive

# We need to purge python3-xdg because
# it's breaking STT install later with
# errors about setuptools
#
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        curl \
        git \
        libboost-all-dev \
        libbz2-dev \
        libopus0 \
        libopusfile0 \
        libsndfile1 \
        unzip \
        wget \
        sox && \
        libopus0 \
        libopusfile0 \
        libsndfile1 \
        sox \
        libsox-fmt-mp3 && \
    apt-get purge -y python3-xdg && \
    rm -rf /var/lib/apt/lists/
    rm -rf /var/lib/apt/lists/*

# Make sure pip and its deps are up-to-date
# Make sure pip and its dependencies are up-to-date
RUN pip3 install --upgrade pip wheel setuptools

WORKDIR /code

# Tool to convert output graph for inference
RUN wget https://github.com/coqui-ai/STT/releases/download/v0.9.3/convert_graphdef_memmapped_format.linux.amd64.zip -O temp.zip && \
    unzip temp.zip && rm temp.zip

COPY native_client /code/native_client
COPY .git /code/.git
COPY training/coqui_stt_training/VERSION /code/training/coqui_stt_training/VERSION
@@ -43,22 +77,20 @@ COPY training/coqui_stt_training/GRAPH_VERSION /code/training/coqui_stt_training
RUN cd native_client/ctcdecode && make NUM_PROCESSES=$(nproc) bindings
RUN pip3 install --upgrade native_client/ctcdecode/dist/*.whl

# Install STT
# - No need for the decoder since we did it earlier
# - There is already correct TensorFlow GPU installed on the base image,
#   we don't want to break that
COPY setup.py /code/setup.py
COPY VERSION /code/VERSION
COPY training /code/training
RUN DS_NODECODER=y DS_NOTENSORFLOW=y pip3 install --upgrade -e .
# Copy files from previous build stages
RUN mkdir -p /code/kenlm/build/
COPY --from=kenlm-build /code/kenlm/build/bin /code/kenlm/build/bin
COPY --from=wget-binaries /convert_graphdef_memmapped_format /code/convert_graphdef_memmapped_format
COPY --from=wget-binaries /generate_scorer_package /code/generate_scorer_package

# Build KenLM to generate new scorers
COPY kenlm /code/kenlm
RUN cd /code/kenlm && \
    mkdir -p build && \
    cd build && \
    cmake .. && \
    make -j $(nproc)
# Install STT
# No need for the decoder since we did it earlier
# TensorFlow GPU should already be installed on the base image,
# and we don't want to break that
RUN DS_NODECODER=y DS_NOTENSORFLOW=y pip3 install --upgrade -e .

# Copy rest of the code and test training
COPY . /code
Dockerfile.train.jupyter (new file) | 12
@@ -0,0 +1,12 @@
# This is a Dockerfile useful for training models with Coqui STT in Jupyter notebooks

FROM ghcr.io/coqui-ai/stt-train:latest

WORKDIR /code/notebooks

RUN python3 -m pip install --no-cache-dir jupyter jupyter_http_over_ws
RUN jupyter serverextension enable --py jupyter_http_over_ws

EXPOSE 8888

CMD ["bash", "-c", "jupyter notebook --notebook-dir=/code/notebooks --ip 0.0.0.0 --no-browser --allow-root"]
MANIFEST.in (new file) | 2
@@ -0,0 +1,2 @@
include training/coqui_stt_training/VERSION
include training/coqui_stt_training/GRAPH_VERSION
@@ -14,7 +14,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --scorer "" \
  --augment dropout \

@@ -14,7 +14,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
  --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -20,7 +20,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_sdb} --train_batch_size 1 \
  --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
  --test_files ${ldc93s1_sdb} --test_batch_size 1 \

@@ -17,7 +17,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --feature_cache '/tmp/ldc93s1_cache' \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
@@ -26,4 +27,5 @@ python -u train.py --show_progressbar false --early_stop false \
  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
  --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train' \
  --scorer_path 'data/smoke_test/pruned_lm.scorer' \
  --audio_sample_rate ${audio_sample_rate}
  --audio_sample_rate ${audio_sample_rate} \
  --export_tflite false

@@ -27,4 +27,5 @@ python -u train.py --show_progressbar false --early_stop false \
  --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_bytes' \
  --scorer_path 'data/smoke_test/pruned_lm.bytes.scorer' \
  --audio_sample_rate ${audio_sample_rate} \
  --bytes_output_mode true
  --bytes_output_mode true \
  --export_tflite false

@@ -17,7 +17,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
  --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -23,7 +23,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_sdb} --train_batch_size 1 \
  --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
  --test_files ${ldc93s1_sdb} --test_batch_size 1 \

@@ -23,7 +23,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_sdb} ${ldc93s1_csv} --train_batch_size 1 \
  --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
  --dev_files ${ldc93s1_sdb} ${ldc93s1_csv} --dev_batch_size 1 \

@@ -14,7 +14,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false --early_stop false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false --early_stop false \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
  --test_files ${ldc93s1_csv} --test_batch_size 1 \
@@ -23,7 +24,7 @@ python -u train.py --show_progressbar false --early_stop false \
  --learning_rate 0.001 --dropout_rate 0.05 \
  --scorer_path 'data/smoke_test/pruned_lm.scorer'

python -u train.py \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --n_hidden 100 \
  --checkpoint_dir '/tmp/ckpt' \
  --scorer_path 'data/smoke_test/pruned_lm.scorer' \

@@ -16,7 +16,8 @@ fi;
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false \
  --n_hidden 100 \
  --checkpoint_dir '/tmp/ckpt' \
  --export_dir '/tmp/train_tflite' \
@@ -26,7 +27,8 @@ python -u train.py --show_progressbar false \

mkdir /tmp/train_tflite/en-us

python -u train.py --show_progressbar false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false \
  --n_hidden 100 \
  --checkpoint_dir '/tmp/ckpt' \
  --export_dir '/tmp/train_tflite/en-us' \
bin/run-ldc93s1.py (new executable file) | 25
@@ -0,0 +1,25 @@
#!/usr/bin/env python
import os
from import_ldc93s1 import _download_and_preprocess_data as download_ldc
from coqui_stt_training.util.config import initialize_globals_from_args
from coqui_stt_training.train import train
from coqui_stt_training.evaluate import test

# only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

download_ldc("data/ldc93s1")

initialize_globals_from_args(
    load_train="init",
    alphabet_config_path="data/alphabet.txt",
    train_files=["data/ldc93s1/ldc93s1.csv"],
    dev_files=["data/ldc93s1/ldc93s1.csv"],
    test_files=["data/ldc93s1/ldc93s1.csv"],
    augment=["time_mask"],
    n_hidden=100,
    epochs=200,
)

train()
test()
@@ -20,7 +20,8 @@ fi
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0

python -u train.py --show_progressbar false \
python -u train.py --alphabet_config_path "data/alphabet.txt" \
  --show_progressbar false \
  --train_files data/ldc93s1/ldc93s1.csv \
  --test_files data/ldc93s1/ldc93s1.csv \
  --train_batch_size 1 \
@@ -55,23 +55,6 @@ maybe_install_xldd()
  fi
}

# Checks whether we run a patched version of bazel.
# Patching is required to dump computeKey() parameters to .ckd files
# See bazel.patch
# Return 0 (success exit code) on patched version, 1 on release version
is_patched_bazel()
{
  bazel_version=$(bazel version | grep 'Build label:' | cut -d':' -f2)

  bazel shutdown

  if [ -z "${bazel_version}" ]; then
    return 0;
  else
    return 1;
  fi;
}

verify_bazel_rebuild()
{
  bazel_explain_file="$1"
@@ -9,21 +9,14 @@ do_bazel_build()
  cd ${DS_TFDIR}
  eval "export ${BAZEL_ENV_FLAGS}"

  if [ "${_opt_or_dbg}" = "opt" ]; then
    if is_patched_bazel; then
      find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-tf.tar -T -
    fi;
  fi;

  bazel ${BAZEL_OUTPUT_USER_ROOT} build \
    -s --explain bazel_monolithic.log --verbose_explanations --experimental_strict_action_env --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
    -s --explain bazel_explain.log --verbose_explanations \
    --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
    -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}

  if [ "${_opt_or_dbg}" = "opt" ]; then
    if is_patched_bazel; then
      find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-ds.tar -T -
    fi;
    verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_monolithic.log"
  fi;
    verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_explain.log"
  fi
}

shutdown_bazel()
@@ -1,26 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")

model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")

download_model_prod

download_material

export PATH=${CI_TMP_DIR}/ds/:$PATH

check_versions

run_prod_inference_tests "${bitrate}"

@@ -1,24 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

download_data

export PATH=${CI_TMP_DIR}/ds/:$PATH

check_versions

run_all_inference_tests

run_multi_inference_tests

run_cpp_only_inference_tests

run_hotword_tests

@@ -1,20 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

download_material "${CI_TMP_DIR}/ds"

export PATH=${CI_TMP_DIR}/ds/:$PATH

check_versions

ensure_cuda_usage "$2"

run_basic_inference_tests

@@ -1,48 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")

download_model_prod

download_data

node --version
npm --version

symlink_electron

export_node_bin_path

which electron
which node

if [ "${OS}" = "Linux" ]; then
  export DISPLAY=':99.0'
  sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
  xvfb_process=$!
fi

node --version

stt --version

check_runtime_electronjs

run_electronjs_prod_inference_tests "${bitrate}"

if [ "${OS}" = "Linux" ]; then
  sleep 1
  sudo kill -9 ${xvfb_process} || true
fi

@@ -1,41 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

download_data

node --version
npm --version

symlink_electron

export_node_bin_path

which electron
which node

if [ "${OS}" = "Linux" ]; then
  export DISPLAY=':99.0'
  sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
  xvfb_process=$!
fi

node --version

stt --version

check_runtime_electronjs

run_electronjs_inference_tests

if [ "${OS}" = "Linux" ]; then
  sleep 1
  sudo kill -9 ${xvfb_process} || true
fi
@@ -2,8 +2,6 @@

set -xe

runtime=$1

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/build-utils.sh
@@ -15,10 +13,7 @@ BAZEL_TARGETS="
//native_client:generate_scorer_package
"

if [ "${runtime}" = "tflite" ]; then
  BAZEL_BUILD_TFLITE="--define=runtime=tflite"
fi;
BAZEL_BUILD_FLAGS="${BAZEL_BUILD_TFLITE} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
BAZEL_BUILD_FLAGS="${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"

BAZEL_ENV_FLAGS="TF_NEED_CUDA=0"
SYSTEM_TARGET=host
@@ -1,30 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")

download_model_prod

download_data

node --version
npm --version

export_node_bin_path

check_runtime_nodejs

run_prod_inference_tests "${bitrate}"

run_js_streaming_prod_inference_tests "${bitrate}"

@@ -1,25 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

download_data

node --version
npm --version

export_node_bin_path

check_runtime_nodejs

run_all_inference_tests

run_js_streaming_inference_tests

run_hotword_tests
@@ -26,9 +26,26 @@ package_native_client()
    win_lib="-C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so.if.lib"
  fi;

  if [ -f "${tensorflow_dir}/bazel-bin/native_client/libkenlm.so.if.lib" ]; then
    win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so.if.lib"
  fi;

  if [ -f "${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib" ]; then
    win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so.if.lib"
  fi;

  libsox_lib=""
  if [ -f "${stt_dir}/sox-build/lib/libsox.so.3" ]; then
    libsox_lib="-C ${stt_dir}/sox-build/lib libsox.so.3"
  fi

  ${TAR} --verbose -cf - \
    --transform='flags=r;s|README.coqui|KenLM_License_Info.txt|' \
    -C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so \
    -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so \
    -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so \
    ${win_lib} \
    ${libsox_lib} \
    -C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
    -C ${stt_dir}/ LICENSE \
    -C ${stt_dir}/native_client/ stt${PLATFORM_EXE_SUFFIX} \
@@ -74,6 +91,7 @@ package_native_client_ndk()
package_libstt_as_zip()
{
  tensorflow_dir=${DS_TFDIR}
  stt_dir=${DS_DSDIR}
  artifacts_dir=${CI_ARTIFACTS_DIR}
  artifact_name=$1

@@ -88,5 +106,14 @@ package_libstt_as_zip()
    echo "Please specify artifact name."
  fi;

  ${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" ${tensorflow_dir}/bazel-bin/native_client/libstt.so
  libsox_lib=""
  if [ -f "${stt_dir}/sox-build/lib/libsox.so.3" ]; then
    libsox_lib="${stt_dir}/sox-build/lib/libsox.so.3"
  fi

  ${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" \
    ${tensorflow_dir}/bazel-bin/native_client/libstt.so \
    ${tensorflow_dir}/bazel-bin/native_client/libkenlm.so \
    ${libsox_lib} \
    ${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so
}
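Illustration only (not part of the diff): the packaging functions above add libkenlm.so and libtensorflowlite.so unconditionally but include libsox.so.3 only when it exists on disk. The same conditional-inclusion pattern, sketched with Python's standard library:

import os
import tarfile

def package_libs(archive_path, libs, optional_libs=()):
    # Always pack the required libraries; add optional ones (e.g. a
    # locally built libsox.so.3) only if they are actually present.
    with tarfile.open(archive_path, "w") as tar:
        for lib in libs:
            tar.add(lib, arcname=os.path.basename(lib))
        for lib in optional_libs:
            if os.path.isfile(lib):
                tar.add(lib, arcname=os.path.basename(lib))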
@@ -1,29 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")

model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")

download_model_prod

download_material

export_py_bin_path

which stt
stt --version

run_prod_inference_tests "${bitrate}"

run_prod_concurrent_stream_tests "${bitrate}"

@@ -1,21 +0,0 @@
#!/bin/bash

set -xe

source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh

bitrate=$1
set_ldc_sample_filename "${bitrate}"

download_data

export_py_bin_path

which stt
stt --version

run_all_inference_tests

run_hotword_tests
@@ -6,30 +6,20 @@ set -o pipefail
source $(dirname $0)/tf-vars.sh

pushd ${DS_ROOT_TASK}/tensorflow/
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s --explain bazel_monolithic_tf.log --verbose_explanations --experimental_strict_action_env --config=monolithic"

# Start a bazel process to ensure reliability on Windows and avoid:
# FATAL: corrupt installation: file 'c:\builds\tc-workdir\.bazel_cache/install/6b1660721930e9d5f231f7d2a626209b/_embedded_binaries/build-runfiles.exe' missing.
bazel ${BAZEL_OUTPUT_USER_ROOT} info

# Force toolchain sync (useful on macOS ?)
bazel ${BAZEL_OUTPUT_USER_ROOT} sync --configure
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s"

MAYBE_DEBUG=$2
OPT_OR_DBG="-c opt"
if [ "${MAYBE_DEBUG}" = "dbg" ]; then
  OPT_OR_DBG="-c dbg"
  OPT_OR_DBG="-c dbg"
fi;

case "$1" in
  "--windows-cpu")
    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIBSTT} ${BUILD_TARGET_LITE_LIB} --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh"
    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
    ;;
  "--linux-cpu"|"--darwin-cpu")
    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_LITE_LIB}
    ;;
  "--linux-cuda"|"--windows-cuda")
    eval "export ${TF_CUDA_FLAGS}" && (echo "" | TF_NEED_CUDA=1 ./configure) && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_CUDA_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API}
    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
    ;;
  "--linux-armv7")
    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
@@ -50,6 +40,4 @@ pushd ${DS_ROOT_TASK}/tensorflow/
    echo "" | TF_NEED_CUDA=0 TF_CONFIGURE_IOS=1 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_IOS_X86_64_FLAGS} ${BUILD_TARGET_LITE_LIB}
    ;;
esac

bazel ${BAZEL_OUTPUT_USER_ROOT} shutdown
popd
@@ -6,26 +6,17 @@ source $(dirname $0)/tf-vars.sh

mkdir -p ${CI_ARTIFACTS_DIR} || true

cp ${DS_ROOT_TASK}/tensorflow/bazel_*.log ${CI_ARTIFACTS_DIR} || true

OUTPUT_ROOT="${DS_ROOT_TASK}/tensorflow/bazel-bin"

for output_bin in \
  tensorflow/lite/experimental/c/libtensorflowlite_c.so \
  tensorflow/tools/graph_transforms/transform_graph \
  tensorflow/tools/graph_transforms/summarize_graph \
  tensorflow/tools/benchmark/benchmark_model \
  tensorflow/contrib/util/convert_graphdef_memmapped_format \
  tensorflow/lite/toco/toco;
for output_bin in \
  tensorflow/lite/libtensorflow.so \
  tensorflow/lite/libtensorflow.so.if.lib \
  ;
do
  if [ -f "${OUTPUT_ROOT}/${output_bin}" ]; then
    cp ${OUTPUT_ROOT}/${output_bin} ${CI_ARTIFACTS_DIR}/
  fi;
done;

if [ -f "${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model" ]; then
  cp ${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model ${CI_ARTIFACTS_DIR}/lite_benchmark_model
fi
done

# It seems that bsdtar and gnutar are behaving a bit differently on the way
# they deal with --exclude="./public/*" ; this caused ./STT/tensorflow/core/public/
@@ -5,12 +5,7 @@ set -ex
source $(dirname $0)/tf-vars.sh

install_android=
install_cuda=
case "$1" in
  "--linux-cuda"|"--windows-cuda")
    install_cuda=yes
    ;;

  "--android-armv7"|"--android-arm64")
    install_android=yes
    ;;
@@ -22,18 +17,13 @@ download()
{
  fname=`basename $1`

  ${WGET} $1 -O ${DS_ROOT_TASK}/dls/$fname && echo "$2 ${DS_ROOT_TASK}/dls/$fname" | ${SHA_SUM} -
  ${CURL} -sSL -o ${DS_ROOT_TASK}/dls/$fname $1 && echo "$2 ${DS_ROOT_TASK}/dls/$fname" | ${SHA_SUM} -
}

# Download stuff
mkdir -p ${DS_ROOT_TASK}/dls || true
download $BAZEL_URL $BAZEL_SHA256

if [ ! -z "${install_cuda}" ]; then
  download $CUDA_URL $CUDA_SHA256
  download $CUDNN_URL $CUDNN_SHA256
fi;

if [ ! -z "${install_android}" ]; then
  download $ANDROID_NDK_URL $ANDROID_NDK_SHA256
  download $ANDROID_SDK_URL $ANDROID_SDK_SHA256
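For illustration (not part of the diff): the download() helper above fetches a file and then pipes "<sha256>  <path>" into ${SHA_SUM} (shasum -c) to verify it. The same verification step in Python, using only the standard library:

import hashlib

def verify_sha256(path, expected_hex):
    # Stream the file in 1 MiB chunks and compare the digest against the
    # expected checksum, like `echo "<sha>  <file>" | shasum -a 256 -c`.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == expected_hex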
@@ -44,49 +34,21 @@ ls -hal ${DS_ROOT_TASK}/dls/

# Install Bazel in ${DS_ROOT_TASK}/bin
BAZEL_INSTALL_FILENAME=$(basename "${BAZEL_URL}")
if [ "${OS}" = "Linux" ]; then
  BAZEL_INSTALL_FLAGS="--user"
elif [ "${OS}" = "Darwin" ]; then
  BAZEL_INSTALL_FLAGS="--bin=${DS_ROOT_TASK}/bin --base=${DS_ROOT_TASK}/.bazel"
fi;
mkdir -p ${DS_ROOT_TASK}/bin || true
pushd ${DS_ROOT_TASK}/bin
if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
  cp ${DS_ROOT_TASK}/dls/${BAZEL_INSTALL_FILENAME} ${DS_ROOT_TASK}/bin/bazel.exe
else
  /bin/bash ${DS_ROOT_TASK}/dls/${BAZEL_INSTALL_FILENAME} ${BAZEL_INSTALL_FLAGS}
fi
popd

SUFFIX=""
if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
  SUFFIX=".exe"
fi

cp ${DS_ROOT_TASK}/dls/${BAZEL_INSTALL_FILENAME} ${DS_ROOT_TASK}/bin/bazel${SUFFIX}
chmod +x ${DS_ROOT_TASK}/bin/bazel${SUFFIX}

# For debug
bazel version

bazel shutdown

if [ ! -z "${install_cuda}" ]; then
  # Install CUDA and CuDNN
  mkdir -p ${DS_ROOT_TASK}/STT/CUDA/ || true
  pushd ${DS_ROOT_TASK}
  CUDA_FILE=`basename ${CUDA_URL}`
  PERL5LIB=. sh ${DS_ROOT_TASK}/dls/${CUDA_FILE} --silent --override --toolkit --toolkitpath=${DS_ROOT_TASK}/STT/CUDA/ --defaultroot=${DS_ROOT_TASK}/STT/CUDA/

  CUDNN_FILE=`basename ${CUDNN_URL}`
  tar xvf ${DS_ROOT_TASK}/dls/${CUDNN_FILE} --strip-components=1 -C ${DS_ROOT_TASK}/STT/CUDA/
  popd

  LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
  export LD_LIBRARY_PATH

  # We might lack libcuda.so.1 symlink, let's fix as upstream does:
  # https://github.com/tensorflow/tensorflow/pull/13811/files?diff=split#diff-2352449eb75e66016e97a591d3f0f43dR96
  if [ ! -h "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1" ]; then
    ln -s "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so" "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1"
  fi;

else
  echo "No CUDA/CuDNN to install"
fi

if [ ! -z "${install_android}" ]; then
  mkdir -p ${DS_ROOT_TASK}/STT/Android/SDK || true
  ANDROID_NDK_FILE=`basename ${ANDROID_NDK_URL}`
@@ -105,8 +67,3 @@ if [ ! -z "${install_android}" ]; then
fi

mkdir -p ${CI_ARTIFACTS_DIR} || true


# Taken from https://www.tensorflow.org/install/source
# Only future is needed for our builds, as we don't build the Python package
python -m pip install -U --user future==0.17.1 || true
@@ -6,15 +6,8 @@ export OS=$(uname)
if [ "${OS}" = "Linux" ]; then
  export DS_ROOT_TASK=${CI_TASK_DIR}

  BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh
  BAZEL_SHA256=7ba815cbac712d061fe728fef958651512ff394b2708e89f79586ec93d1185ed

  CUDA_URL=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
  CUDA_SHA256=e7c22dc21278eb1b82f34a60ad7640b41ad3943d929bebda3008b72536855d31

  # From https://gitlab.com/nvidia/cuda/blob/centos7/10.1/devel/cudnn7/Dockerfile
  CUDNN_URL=http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.1-linux-x64-v7.6.0.64.tgz
  CUDNN_SHA256=e956c6f9222fcb867a10449cfc76dee5cfd7c7531021d95fe9586d7e043b57d7
  BAZEL_URL=https://github.com/bazelbuild/bazelisk/releases/download/v1.10.1/bazelisk-linux-amd64
  BAZEL_SHA256=4cb534c52cdd47a6223d4596d530e7c9c785438ab3b0a49ff347e991c210b2cd

  ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r18b-linux-x86_64.zip
  ANDROID_NDK_SHA256=4f61cbe4bbf6406aa5ef2ae871def78010eed6271af72de83f8bd0b07a9fd3fd
@@ -45,10 +38,8 @@ elif [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
  export TEMP=${CI_TASK_DIR}/tmp/
  export TMP=${CI_TASK_DIR}/tmp/

  BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-windows-x86_64.exe
  BAZEL_SHA256=776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9

  CUDA_INSTALL_DIRECTORY=$(cygpath 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1')
  BAZEL_URL=https://github.com/bazelbuild/bazelisk/releases/download/v1.10.1/bazelisk-windows-amd64.exe
  BAZEL_SHA256=9a89e6a8cc0a3aea37affcf8c146d8925ffbda1d2290c0c6a845ea81e05de62c

  TAR=/usr/bin/tar.exe
elif [ "${OS}" = "Darwin" ]; then
@@ -61,14 +52,15 @@ elif [ "${OS}" = "Darwin" ]; then

  export DS_ROOT_TASK=${CI_TASK_DIR}

  BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-darwin-x86_64.sh
  BAZEL_SHA256=5cfa97031b43432b3c742c80e2e01c41c0acdca7ba1052fc8cf1e291271bc9cd
  BAZEL_URL=https://github.com/bazelbuild/bazelisk/releases/download/v1.10.1/bazelisk-darwin-amd64
  BAZEL_SHA256=e485bbf84532d02a60b0eb23c702610b5408df3a199087a4f2b5e0995bbf2d5a

  SHA_SUM="shasum -a 256 -c"
  TAR=gtar
fi;

WGET=${WGET:-"wget"}
CURL=${CURL:-"curl"}
TAR=${TAR:-"tar"}
XZ=${XZ:-"xz -9 -T0"}
ZIP=${ZIP:-"zip"}
@@ -89,7 +81,6 @@ fi;
export PATH

if [ "${OS}" = "Linux" ]; then
  export LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
  export ANDROID_SDK_HOME=${DS_ROOT_TASK}/STT/Android/SDK/
  export ANDROID_NDK_HOME=${DS_ROOT_TASK}/STT/Android/android-ndk-r18b/
fi;
@@ -120,8 +111,8 @@ export GCC_HOST_COMPILER_PATH=/usr/bin/gcc

if [ "${OS}" = "Linux" ]; then
  source /etc/os-release
  if [ "${ID}" = "ubuntu" -a "${VERSION_ID}" = "20.04" ]; then
    export PYTHON_BIN_PATH=/usr/bin/python3
  if [ "${ID}" = "debian" -a "${VERSION_ID}" = "9" ]; then
    export PYTHON_BIN_PATH=/opt/python/cp37-cp37m/bin/python
  fi
elif [ "${OS}" != "${TC_MSYS_VERSION}" ]; then
  export PYTHON_BIN_PATH=python
@@ -160,27 +151,15 @@ export BAZEL_OUTPUT_USER_ROOT

NVCC_COMPUTE="3.5"

### Define build parameters/env variables that we will re-ues in sourcing scripts.
if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
  TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_PATHS=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
else
  TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_PATHS=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
fi
BAZEL_ARM_FLAGS="--config=rpi3 --config=rpi3_opt --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_ARM64_FLAGS="--config=rpi3-armv8 --config=rpi3-armv8_opt --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_ANDROID_ARM_FLAGS="--config=android --config=android_arm --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_ANDROID_ARM64_FLAGS="--config=android --config=android_arm64 --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_CUDA_FLAGS="--config=cuda"
if [ "${OS}" = "Linux" ]; then
  # constexpr usage in tensorflow's absl dep fails badly because of gcc-5
  # so let's skip that
  BAZEL_CUDA_FLAGS="${BAZEL_CUDA_FLAGS} --copt=-DNO_CONSTEXPR_FOR_YOU=1"
fi
BAZEL_IOS_ARM64_FLAGS="--config=ios_arm64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_IOS_X86_64_FLAGS="--config=ios_x86_64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"

if [ "${OS}" != "${CI_MSYS_VERSION}" ]; then
  BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl --copt=-fvisibility=hidden"
  BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl"
fi

if [ "${OS}" = "Darwin" ]; then
@@ -189,11 +168,5 @@ fi

### Define build targets that we will re-ues in sourcing scripts.
BUILD_TARGET_LIB_CPP_API="//tensorflow:tensorflow_cc"
BUILD_TARGET_GRAPH_TRANSFORMS="//tensorflow/tools/graph_transforms:transform_graph"
BUILD_TARGET_GRAPH_SUMMARIZE="//tensorflow/tools/graph_transforms:summarize_graph"
BUILD_TARGET_GRAPH_BENCHMARK="//tensorflow/tools/benchmark:benchmark_model"
#BUILD_TARGET_CONVERT_MMAP="//tensorflow/contrib/util:convert_graphdef_memmapped_format"
BUILD_TARGET_TOCO="//tensorflow/lite/toco:toco"
BUILD_TARGET_LITE_BENCHMARK="//tensorflow/lite/tools/benchmark:benchmark_model"
BUILD_TARGET_LITE_LIB="//tensorflow/lite/c:libtensorflowlite_c.so"
BUILD_TARGET_LITE_LIB="//tensorflow/lite:libtensorflowlite.so"
BUILD_TARGET_LIBSTT="//native_client:libstt.so"
@@ -66,3 +66,6 @@ time ./bin/run-ci-ldc93s1_checkpoint_sdb.sh

# Bytes output mode, resuming from checkpoint
time ./bin/run-ci-ldc93s1_checkpoint_bytes.sh

# Training with args set via initialize_globals_from_args()
time python ./bin/run-ldc93s1.py

@@ -1,4 +1,4 @@


о
е
а
@@ -7,65 +7,46 @@ Here we maintain the list of supported platforms for deployment.

*Note that 🐸STT currently only provides packages for CPU deployment with Python 3.5 or higher on Linux. We're working to get the rest of our usually supported packages back up and running as soon as possible.*

Linux / AMD64 without GPU
Linux / AMD64
^^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Ubuntu 14.04+ (glibc >= 2.19, libstdc++6 >= 4.8)
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)

Linux / AMD64 with GPU
^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Ubuntu 14.04+ (glibc >= 2.19, libstdc++6 >= 4.8)
* CUDA 10.0 (and capable GPU)
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
* glibc >= 2.24, libstdc++6 >= 6.3
* TensorFlow Lite runtime

Linux / ARMv7
^^^^^^^^^^^^^
* Cortex-A53 compatible ARMv7 SoC with Neon support
* Raspbian Buster-compatible distribution
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime

Linux / Aarch64
^^^^^^^^^^^^^^^
* Cortex-A72 compatible Aarch64 SoC
* ARMbian Buster-compatible distribution
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime

Android / ARMv7
^^^^^^^^^^^^^^^
* ARMv7 SoC with Neon support
* Android 7.0-10.0
* NDK API level >= 21
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime

Android / Aarch64
^^^^^^^^^^^^^^^^^
* Aarch64 SoC
* Android 7.0-10.0
* NDK API level >= 21
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime

macOS / AMD64
^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* macOS >= 10.10
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime

Windows / AMD64 without GPU
Windows / AMD64
^^^^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Windows Server >= 2012 R2 ; Windows >= 8.1
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)

Windows / AMD64 with GPU
^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Windows Server >= 2012 R2 ; Windows >= 8.1
* CUDA 10.0 (and capable GPU)
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
@@ -8,7 +8,7 @@ Below you can find the definition of all command-line flags supported by the tra
Flags
-----

.. literalinclude:: ../training/coqui_stt_training/util/flags.py
.. literalinclude:: ../training/coqui_stt_training/util/config.py
   :language: python
   :linenos:
   :lineno-match:

@@ -19,6 +19,8 @@ Coqui STT

   TRAINING_INTRO

   TRAINING_ADVANCED

   BUILDING

   Quickstart: Deployment

@@ -39,7 +39,7 @@ Numbers should be written in full (ie as a [cardinal](https://en.wikipedia.org/w

### Data from Common Voice

If you are using data from Common Voice for training a model, you will need to prepare it as [outlined in the 🐸STT documentation](https://stt.readthedocs.io/en/latest/TRAINING.html#common-voice-training-data).
If you are using data from Common Voice for training a model, you will need to prepare it as [outlined in the 🐸STT documentation](https://stt.readthedocs.io/en/latest/COMMON_VOICE_DATA.html#common-voice-data).

In this example we will prepare the Indonesian dataset for training, but you can use any language from Common Voice that you prefer. We've chosen Indonesian as it has the same [orthographic alphabet](ALPHABET.md) as English, which means we don't have to use a different `alphabet.txt` file for training; we can use the default.
@@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function

import sys

import absl.app
import optuna
import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer
from coqui_stt_training.evaluate import evaluate
from coqui_stt_training.train import create_model
from coqui_stt_training.util.config import Config, initialize_globals
from coqui_stt_training.train import create_model, early_training_checks
from coqui_stt_training.util.config import (
    Config,
    initialize_globals_from_cli,
    log_error,
)
from coqui_stt_training.util.evaluate_tools import wer_cer_batch
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import log_error


def character_based():
    is_character_based = False
    if FLAGS.scorer_path:
        scorer = Scorer(
            FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
        )
        is_character_based = scorer.is_utf8_mode()
    scorer = Scorer(
        Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
    )
    is_character_based = scorer.is_utf8_mode()
    return is_character_based


def objective(trial):
    FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max)
    FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max)
    Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max)
    Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max)

    is_character_based = trial.study.user_attrs["is_character_based"]

    samples = []
    for step, test_file in enumerate(FLAGS.test_files.split(",")):
    for step, test_file in enumerate(Config.test_files):
        tfv1.reset_default_graph()

        current_samples = evaluate([test_file], create_model)
@@ -51,10 +51,18 @@ def objective(trial):
    return cer if is_character_based else wer


def main(_):
    initialize_globals()
def main():
    initialize_globals_from_cli()
    early_training_checks()

    if not FLAGS.test_files:
    if not Config.scorer_path:
        log_error(
            "Missing --scorer_path: can't optimize scorer alpha and beta "
            "parameters without a scorer!"
        )
        sys.exit(1)

    if not Config.test_files:
        log_error(
            "You need to specify what files to use for evaluation via "
            "the --test_files flag."
@@ -65,7 +73,7 @@ def main(_):

    study = optuna.create_study()
    study.set_user_attr("is_character_based", is_character_based)
    study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
    study.optimize(objective, n_jobs=1, n_trials=Config.n_trials)
    print(
        "Best params: lm_alpha={} and lm_beta={} with WER={}".format(
            study.best_params["lm_alpha"],
@@ -76,5 +84,4 @@ def main(_):


if __name__ == "__main__":
    create_flags()
    absl.app.run(main)
    main()
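For reference (not part of the diff): a minimal, self-contained sketch of the Optuna pattern lm_optimizer.py uses — create a study, let each trial suggest lm_alpha/lm_beta, and minimize the returned error. The toy objective below stands in for the real WER/CER evaluation over the test files.

import optuna

def objective(trial):
    # Same parameter-suggestion pattern as above (suggest_float is the
    # non-deprecated equivalent of suggest_uniform).
    lm_alpha = trial.suggest_float("lm_alpha", 0, 5)
    lm_beta = trial.suggest_float("lm_beta", 0, 5)
    # Toy stand-in for computing WER after an evaluation run.
    return (lm_alpha - 0.93) ** 2 + (lm_beta - 1.18) ** 2

study = optuna.create_study()
study.optimize(objective, n_jobs=1, n_trials=20)
print("Best params:", study.best_params)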
@@ -1,22 +1,9 @@
# Description: Coqui STT native client library.

load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object", "tf_copts", "lrt_if_needed")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "lrt_if_needed")
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")

load(
    "@org_tensorflow//tensorflow/lite:build_def.bzl",
    "tflite_copts",
    "tflite_linkopts",
)

config_setting(
    name = "tflite",
    define_values = {
        "runtime": "tflite",
    },
)

config_setting(
    name = "rpi3",
@@ -52,6 +39,31 @@ OPENFST_INCLUDES_PLATFORM = select({
    "//conditions:default": ["ctcdecode/third_party/openfst-1.6.7/src/include"],
})

DECODER_SOURCES = [
    "alphabet.cc",
    "alphabet.h",
    "ctcdecode/ctc_beam_search_decoder.cpp",
    "ctcdecode/ctc_beam_search_decoder.h",
    "ctcdecode/decoder_utils.cpp",
    "ctcdecode/decoder_utils.h",
    "ctcdecode/path_trie.cpp",
    "ctcdecode/path_trie.h",
    "ctcdecode/scorer.cpp",
    "ctcdecode/scorer.h",
] + OPENFST_SOURCES_PLATFORM

DECODER_INCLUDES = [
    ".",
    "ctcdecode/third_party/ThreadPool",
    "ctcdecode/third_party/object_pool",
] + OPENFST_INCLUDES_PLATFORM

DECODER_LINKOPTS = [
    "-lm",
    "-ldl",
    "-pthread",
]

LINUX_LINKOPTS = [
    "-ldl",
    "-pthread",
@@ -60,10 +72,12 @@ LINUX_LINKOPTS = [
    "-Wl,-export-dynamic",
]

cc_library(
    name = "kenlm",
cc_binary(
    name = "libkenlm.so",
    srcs = glob([
        "kenlm/lm/*.hh",
        "kenlm/lm/*.cc",
        "kenlm/util/*.hh",
        "kenlm/util/*.cc",
        "kenlm/util/double-conversion/*.cc",
        "kenlm/util/double-conversion/*.h",
@@ -72,10 +86,36 @@ cc_library(
        "kenlm/*/*test.cc",
        "kenlm/*/*main.cc",
    ],),
    copts = [
        "-std=c++11"
    ] + select({
        "//tensorflow:windows": [],
        "//conditions:default": ["-fvisibility=hidden"],
    }),
    defines = ["KENLM_MAX_ORDER=6"],
    includes = ["kenlm"],
    linkshared = 1,
    linkopts = select({
        "//tensorflow:ios": [
            "-Wl,-install_name,@rpath/libkenlm.so",
        ],
        "//tensorflow:macos": [
            "-Wl,-install_name,@rpath/libkenlm.so",
        ],
        "//tensorflow:windows": [],
        "//conditions:default": [
            "-Wl,-soname,libkenlm.so",
        ],
    }),
)

cc_library(
    name = "kenlm",
    hdrs = glob([
        "kenlm/lm/*.hh",
        "kenlm/util/*.hh",
    ]),
    srcs = ["libkenlm.so"],
    copts = ["-std=c++11"],
    defines = ["KENLM_MAX_ORDER=6"],
    includes = ["kenlm"],
@@ -83,32 +123,25 @@ cc_library(

cc_library(
    name = "decoder",
    srcs = [
        "ctcdecode/ctc_beam_search_decoder.cpp",
        "ctcdecode/decoder_utils.cpp",
        "ctcdecode/decoder_utils.h",
        "ctcdecode/scorer.cpp",
        "ctcdecode/path_trie.cpp",
        "ctcdecode/path_trie.h",
        "alphabet.cc",
    ] + OPENFST_SOURCES_PLATFORM,
    hdrs = [
        "ctcdecode/ctc_beam_search_decoder.h",
        "ctcdecode/scorer.h",
        "ctcdecode/decoder_utils.h",
        "alphabet.h",
    ],
    includes = [
        ".",
        "ctcdecode/third_party/ThreadPool",
        "ctcdecode/third_party/object_pool",
    ] + OPENFST_INCLUDES_PLATFORM,
    srcs = DECODER_SOURCES,
    includes = DECODER_INCLUDES,
    deps = [":kenlm"],
    linkopts = [
        "-lm",
        "-ldl",
        "-pthread",
    linkopts = DECODER_LINKOPTS,
    copts = ["-fexceptions"],
)

cc_library(
    name = "tflite",
    hdrs = [
        "//tensorflow/lite:model.h",
        "//tensorflow/lite/kernels:register.h",
        "//tensorflow/lite/tools/evaluation:utils.h",
    ],
    srcs = [
        "//tensorflow/lite:libtensorflowlite.so",
    ],
    includes = ["tensorflow"],
    deps = ["//tensorflow/lite:libtensorflowlite.so"],
)

cc_library(
@@ -121,17 +154,10 @@ cc_library(
        "modelstate.h",
        "workspace_status.cc",
        "workspace_status.h",
    ] + select({
        "//native_client:tflite": [
            "tflitemodelstate.h",
            "tflitemodelstate.cc",
        ],
        "//conditions:default": [
            "tfmodelstate.h",
            "tfmodelstate.cc",
        ],
    }),
    copts = tf_copts() + select({
        "tflitemodelstate.h",
        "tflitemodelstate.cc",
    ] + DECODER_SOURCES,
    copts = select({
        # -fvisibility=hidden is not required on Windows, MSVC hides all declarations by default
        "//tensorflow:windows": ["/w"],
        # -Wno-sign-compare to silence a lot of warnings from tensorflow itself,
@@ -140,73 +166,42 @@ cc_library(
            "-Wno-sign-compare",
            "-fvisibility=hidden",
        ],
    }) + select({
        "//native_client:tflite": ["-DUSE_TFLITE"],
        "//conditions:default": ["-UUSE_TFLITE"],
    }) + tflite_copts(),
    }),
    linkopts = lrt_if_needed() + select({
        "//tensorflow:macos": [],
        "//tensorflow:ios": ["-fembed-bitcode"],
        "//tensorflow:linux_x86_64": LINUX_LINKOPTS,
        "//native_client:rpi3": LINUX_LINKOPTS,
        "//native_client:rpi3-armv8": LINUX_LINKOPTS,
        "//tensorflow:windows": [],
        # Bazel has too strong opinions about static linking, so it's
        # near impossible to get it to link a DLL against another DLL on Windows.
        # We simply force the linker option manually here as a hacky fix.
        "//tensorflow:windows": [
            "bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib",
            "bazel-out/x64_windows-opt/bin/tensorflow/lite/libtensorflowlite.so.if.lib",
        ],
        "//conditions:default": [],
    }) + tflite_linkopts(),
    deps = select({
        "//native_client:tflite": [
            "//tensorflow/lite/kernels:builtin_ops",
            "//tensorflow/lite/tools/evaluation:utils",
        ],
        "//conditions:default": [
            "//tensorflow/core:core_cpu",
            "//tensorflow/core:direct_session",
            "//third_party/eigen3",
            #"//tensorflow/core:all_kernels",
            ### => Trying to be more fine-grained
            ### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
            ### CPU only build, libstt.so file size reduced by ~50%
            "//tensorflow/core/kernels:spectrogram_op",  # AudioSpectrogram
            "//tensorflow/core/kernels:bias_op",  # BiasAdd
            "//tensorflow/core/kernels:cast_op",  # Cast
            "//tensorflow/core/kernels:concat_op",  # ConcatV2
            "//tensorflow/core/kernels:constant_op",  # Const, Placeholder
            "//tensorflow/core/kernels:shape_ops",  # ExpandDims, Shape
            "//tensorflow/core/kernels:gather_nd_op",  # GatherNd
            "//tensorflow/core/kernels:identity_op",  # Identity
            "//tensorflow/core/kernels:immutable_constant_op",  # ImmutableConst (used in memmapped models)
            "//tensorflow/core/kernels:deepspeech_cwise_ops",  # Less, Minimum, Mul
            "//tensorflow/core/kernels:matmul_op",  # MatMul
            "//tensorflow/core/kernels:reduction_ops",  # Max
            "//tensorflow/core/kernels:mfcc_op",  # Mfcc
            "//tensorflow/core/kernels:no_op",  # NoOp
            "//tensorflow/core/kernels:pack_op",  # Pack
            "//tensorflow/core/kernels:sequence_ops",  # Range
            "//tensorflow/core/kernels:relu_op",  # Relu
            "//tensorflow/core/kernels:reshape_op",  # Reshape
            "//tensorflow/core/kernels:softmax_op",  # Softmax
            "//tensorflow/core/kernels:tile_ops",  # Tile
            "//tensorflow/core/kernels:transpose_op",  # Transpose
            "//tensorflow/core/kernels:rnn_ops",  # BlockLSTM
            # And we also need the op libs for these ops used in the model:
            "//tensorflow/core:audio_ops_op_lib",  # AudioSpectrogram, Mfcc
            "//tensorflow/core:rnn_ops_op_lib",  # BlockLSTM
            "//tensorflow/core:math_ops_op_lib",  # Cast, Less, Max, MatMul, Minimum, Range
            "//tensorflow/core:array_ops_op_lib",  # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
            "//tensorflow/core:no_op_op_lib",  # NoOp
            "//tensorflow/core:nn_ops_op_lib",  # Relu, Softmax, BiasAdd
            # And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
            "//tensorflow/core:dataset_ops_op_lib",  # UnwrapDatasetVariant, WrapDatasetVariant
            "//tensorflow/core:sendrecv_ops_op_lib",  # _HostRecv, _HostSend, _Recv, _Send
        ],
    }) + if_cuda([
        "//tensorflow/core:core",
    ]) + [":decoder"],
    }) + DECODER_LINKOPTS,
    includes = DECODER_INCLUDES,
    deps = [":kenlm", ":tflite"],
)

tf_cc_shared_object(
cc_binary(
    name = "libstt.so",
    deps = [":coqui_stt_bundle"],
    linkshared = 1,
    linkopts = select({
        "//tensorflow:ios": [
            "-Wl,-install_name,@rpath/libstt.so",
        ],
        "//tensorflow:macos": [
            "-Wl,-install_name,@rpath/libstt.so",
        ],
        "//tensorflow:windows": [],
        "//conditions:default": [
            "-Wl,-soname,libstt.so",
        ],
    }),
)

ios_static_framework(
@@ -231,9 +226,13 @@ cc_binary(
        "generate_scorer_package.cpp",
        "stt_errors.cc",
    ],
    copts = ["-std=c++11"],
    copts = select({
        "//tensorflow:windows": [],
        "//conditions:default": ["-std=c++11"],
    }),
    deps = [
        ":decoder",
        ":kenlm",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/flags:parse",
        "@com_google_absl//absl/types:optional",
@@ -247,6 +246,10 @@ cc_binary(
    ] + select({
        # ARMv7: error: Android 5.0 and later only support position-independent executables (-fPIE).
        "//tensorflow:android": ["-fPIE -pie"],
        # Bazel has too strong opinions about static linking, so it's
        # near impossible to get it to link a DLL against another DLL on Windows.
        # We simply force the linker option manually here as a hacky fix.
        "//tensorflow:windows": ["bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib"],
        "//conditions:default": [],
    }),
)
@@ -263,9 +266,8 @@ cc_binary(
cc_binary(
    name = "trie_load",
    srcs = [
        "alphabet.h",
        "trie_load.cc",
    ],
    ] + DECODER_SOURCES,
    copts = ["-std=c++11"],
    deps = [":decoder"],
    linkopts = DECODER_LINKOPTS,
)
@@ -69,6 +69,39 @@ Alphabet::init(const char *config_file)
  return 0;
}

void
Alphabet::InitFromLabels(const std::vector<std::string>& labels)
{
  space_label_ = -2;
  size_ = labels.size();
  for (int i = 0; i < size_; ++i) {
    const std::string& label = labels[i];
    if (label == " ") {
      space_label_ = i;
    }
    label_to_str_[i] = label;
    str_to_label_[label] = i;
  }
}

std::string
Alphabet::SerializeText()
{
  std::stringstream out;

  out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n"
      << "# associated with a numeric label.\n"
      << "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
      << "# to use '#' as a label.\n";

  for (int label = 0; label < size_; ++label) {
    out << label_to_str_[label] << "\n";
  }

  out << "# The last (non-comment) line needs to end with a newline.\n";
  return out.str();
}

std::string
Alphabet::Serialize()
{
@@ -19,9 +19,15 @@ public:

  virtual int init(const char *config_file);

  // Initialize directly from sequence of labels.
  void InitFromLabels(const std::vector<std::string>& labels);

  // Serialize alphabet into a binary buffer.
  std::string Serialize();

  // Serialize alphabet into a text representation (i.e. the config file format read by `init`).
  std::string SerializeText();

  // Deserialize alphabet from a binary buffer.
  int Deserialize(const char* buffer, const int buffer_size);

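The text representation above doubles as the on-disk alphabet config format: one UTF-8 label per line, with '#' opening a comment and '\#' escaping a literal '#'. A minimal Python sketch of a reader for that format, inferred from the serialized comments in the hunk above (the escape handling is an assumption based on those comments, not a drop-in replacement for the C++ parser):

def read_alphabet_text(path):
    # One label per line; lines starting with '#' are comments.
    labels = []
    with open(path, encoding="utf-8") as fin:
        for line in fin:
            line = line.rstrip("\n")
            if line.startswith("#"):
                continue
            labels.append(line.replace("\\#", "#"))
    return labels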
@@ -45,16 +45,16 @@ workspace_status.cc:
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings: clean-keep-third-party workspace_status.cc $(DS_SWIG_DEP)
	python -m pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==45.0.0
	DISTUTILS_USE_SDK=1 PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN_DIR):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	python -m pip install --quiet $(PYTHON_PACKAGES) wheel setuptools
	DISTUTILS_USE_SDK=1 PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN_DIR):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(SETUP_FLAGS)
	find temp_build -type f -name "*.o" -delete
	DISTUTILS_USE_SDK=1 AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	DISTUTILS_USE_SDK=1 AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(SETUP_FLAGS)
	rm -rf temp_build

bindings-debug: clean-keep-third-party workspace_status.cc $(DS_SWIG_DEP)
	python -m pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==45.0.0
	DISTUTILS_USE_SDK=1 PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN_DIR):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	python -m pip install --quiet $(PYTHON_PACKAGES) wheel setuptools
	DISTUTILS_USE_SDK=1 PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN_DIR):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(SETUP_FLAGS)
	$(GENERATE_DEBUG_SYMS)
	find temp_build -type f -name "*.o" -delete
	DISTUTILS_USE_SDK=1 AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
	DISTUTILS_USE_SDK=1 AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(SETUP_FLAGS)
	rm -rf temp_build
@@ -45,13 +45,17 @@ class Scorer(swigwrapper.Scorer):
class Alphabet(swigwrapper.Alphabet):
    """Convenience wrapper for Alphabet which calls init in the constructor"""

    def __init__(self, config_path):
    def __init__(self, config_path=None):
        super(Alphabet, self).__init__()
        err = self.init(config_path.encode("utf-8"))
        if err != 0:
            raise ValueError(
                "Alphabet initialization failed with error code 0x{:X}".format(err)
            )
        if config_path:
            err = self.init(config_path.encode("utf-8"))
            if err != 0:
                raise ValueError(
                    "Alphabet initialization failed with error code 0x{:X}".format(err)
                )

    def InitFromLabels(self, data):
        return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])

    def CanEncodeSingle(self, input):
        """
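A short usage sketch of the two construction paths this change enables; the import path below is an assumption, so adjust it to wherever this wrapper module lives in your install:

# Hypothetical import path for the swig wrapper shown above.
from coqui_stt_ctcdecoder import Alphabet

# Old path: initialize from an alphabet config file on disk.
alphabet_from_file = Alphabet("data/alphabet.txt")

# New path: construct empty, then build the alphabet in memory.
alphabet_from_labels = Alphabet()
alphabet_from_labels.InitFromLabels([" ", "a", "b", "c"])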
@@ -1,6 +1,7 @@
NC_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))

TARGET ?= host
ROOT_DIR ?= $(abspath $(NC_DIR)/..)
TFDIR ?= $(abspath $(NC_DIR)/../tensorflow)
PREFIX ?= /usr/local
SO_SEARCH ?= $(TFDIR)/bazel-bin/
@@ -20,15 +21,15 @@ endif

STT_BIN := stt$(PLATFORM_EXE_SUFFIX)
CFLAGS_STT := -std=c++11 -o $(STT_BIN)
LINK_STT := -lstt
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client
LINK_STT := -lstt -lkenlm -ltensorflowlite
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client -L${TFDIR}/bazel-bin/tensorflow/lite

ifeq ($(TARGET),host)
TOOLCHAIN :=
CFLAGS :=
CXXFLAGS :=
LDFLAGS :=
SOX_CFLAGS := `pkg-config --cflags sox`
SOX_CFLAGS := -I$(ROOT_DIR)/sox-build/include
ifeq ($(OS),Linux)
MAGIC_LINK_LZMA := $(shell objdump -tTC /usr/lib/`uname -m`-linux-gnu/libmagic.so | grep lzma | grep '*UND*' | wc -l)
ifneq ($(MAGIC_LINK_LZMA),0)
@@ -38,8 +39,7 @@ MAGIC_LINK_BZ2 := $(shell objdump -tTC /usr/lib/`uname -m`-linux-gnu/libmagic.s
ifneq ($(MAGIC_LINK_BZ2),0)
MAYBE_LINK_BZ2 := -lbz2
endif # MAGIC_LINK_BZ2
SOX_CFLAGS += -fopenmp
SOX_LDFLAGS := -Wl,-Bstatic `pkg-config --static --libs sox` -lgsm `pkg-config --static --libs libpng | cut -d' ' -f1` -lz -lmagic $(MAYBE_LINK_LZMA) $(MAYBE_LINK_BZ2) -lltdl -Wl,-Bdynamic -ldl
SOX_LDFLAGS := -L$(ROOT_DIR)/sox-build/lib -lsox
else ifeq ($(OS),Darwin)
LIBSOX_PATH := $(shell echo `pkg-config --libs-only-L sox | sed -e 's/^-L//'`/lib`pkg-config --libs-only-l sox | sed -e 's/^-l//'`.dylib)
LIBOPUSFILE_PATH := $(shell echo `pkg-config --libs-only-L opusfile | sed -e 's/^-L//'`/lib`pkg-config --libs-only-l opusfile | sed -e 's/^-l//'`.dylib)
@@ -51,7 +51,7 @@ SOX_LDFLAGS := `pkg-config --libs sox`
endif # OS others
PYTHON_PACKAGES := numpy${NUMPY_BUILD_VERSION}
ifeq ($(OS),Linux)
PYTHON_PLATFORM_NAME ?= --plat-name manylinux1_x86_64
PYTHON_PLATFORM_NAME ?= --plat-name manylinux_2_24_x86_64
endif
endif

@@ -61,7 +61,7 @@ TOOL_CC := cl.exe
TOOL_CXX := cl.exe
TOOL_LD := link.exe
TOOL_LIBEXE := lib.exe
LINK_STT := $(TFDIR)\bazel-bin\native_client\libstt.so.if.lib
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib")
LINK_PATH_STT :=
CFLAGS_STT := -nologo -Fe$(STT_BIN)
SOX_CFLAGS :=
@@ -175,7 +175,7 @@ define copy_missing_libs
    SRC_FILE=$(1); \
    TARGET_LIB_DIR=$(2); \
    MANIFEST_IN=$(3); \
    echo "Analyzing $$SRC_FILE copying missing libs to $$SRC_FILE"; \
    echo "Analyzing $$SRC_FILE copying missing libs to $$TARGET_LIB_DIR"; \
    echo "Maybe outputting to $$MANIFEST_IN"; \
    \
    (mkdir $$TARGET_LIB_DIR || true); \
@@ -185,12 +185,13 @@ define copy_missing_libs
            new_missing="$$( (for f in $$(otool -L $$lib 2>/dev/null | tail -n +2 | awk '{ print $$1 }' | grep -v '$$lib'); do ls -hal $$f; done;) 2>&1 | grep 'No such' | cut -d':' -f2 | xargs basename -a)"; \
            missing_libs="$$missing_libs $$new_missing"; \
        elif [ "$(OS)" = "${CI_MSYS_VERSION}" ]; then \
            missing_libs="libstt.so"; \
            missing_libs="libstt.so libkenlm.so libtensorflowlite.so"; \
        else \
            missing_libs="$$missing_libs $$($(LDD) $$lib | grep 'not found' | awk '{ print $$1 }')"; \
        fi; \
    done; \
    \
    echo "Missing libs = $$missing_libs"; \
    for missing in $$missing_libs; do \
        find $(SO_SEARCH) -type f -name "$$missing" -exec cp {} $$TARGET_LIB_DIR \; ; \
        chmod +w $$TARGET_LIB_DIR/*.so ; \
@@ -237,7 +238,7 @@ DS_SWIG_ENV := SWIG_LIB="$(SWIG_LIB)" PATH="$(DS_SWIG_BIN_PATH):${PATH}"

$(DS_SWIG_BIN_PATH)/swig:
	mkdir -p $(SWIG_ROOT)
	wget -O - "$(SWIG_DIST_URL)" | tar -C $(SWIG_ROOT) -zxf -
	curl -sSL "$(SWIG_DIST_URL)" | tar -C $(SWIG_ROOT) -zxf -
	ln -s $(DS_SWIG_BIN) $(DS_SWIG_BIN_PATH)/$(SWIG_BIN)

ds-swig: $(DS_SWIG_BIN_PATH)/swig
@@ -50,7 +50,7 @@ configure: stt_wrap.cxx package.json npm-dev
	PATH="$(NODE_MODULES_BIN):${PATH}" $(NODE_BUILD_TOOL) configure $(NODE_BUILD_VERBOSE)

build: configure stt_wrap.cxx
	PATH="$(NODE_MODULES_BIN):${PATH}" NODE_PRE_GYP_ABI_CROSSWALK=$(NODE_PRE_GYP_ABI_CROSSWALK_FILE) AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" CXXFLAGS="$(CXXFLAGS)" LDFLAGS="$(RPATH_NODEJS) $(LDFLAGS)" LIBS=$(LIBS) $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DEVDIR) $(NODE_DIST_URL) --no-color rebuild $(NODE_BUILD_VERBOSE)
	PATH="$(NODE_MODULES_BIN):${PATH}" NODE_PRE_GYP_ABI_CROSSWALK=$(NODE_PRE_GYP_ABI_CROSSWALK_FILE) AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" CXXFLAGS="$(CXXFLAGS)" LDFLAGS="$(RPATH_NODEJS) $(LDFLAGS)" LIBS="$(LIBS)" $(NODE_BUILD_TOOL) $(NODE_PLATFORM_TARGET) $(NODE_RUNTIME) $(NODE_ABI_TARGET) $(NODE_DEVDIR) $(NODE_DIST_URL) --no-color rebuild $(NODE_BUILD_VERBOSE)

copy-deps: build
	$(call copy_missing_libs,lib/binding/*/*/*/stt.node,lib/binding/*/*/)
@@ -63,3 +63,6 @@ npm-pack: clean package.json index.ts npm-dev

stt_wrap.cxx: stt.i ds-swig
	$(DS_SWIG_ENV) swig -c++ -javascript -node stt.i
	# Hack: disable wrapping of constructors to avoid NodeJS 16.6 ABI compat break
	sed -i.bak '/SetCallHandler/d' stt_wrap.cxx
	rm stt_wrap.cxx.bak
@@ -3,7 +3,7 @@
        {
            "target_name": "stt",
            "sources": ["stt_wrap.cxx"],
            "libraries": ["$(LIBS)"],
            "libraries": [],
            "include_dirs": ["../"],
            "conditions": [
                [
@@ -20,7 +20,24 @@
                    ],
                    }
                    },
                ]
                ],
                [
                    "OS=='win'",
                    {
                        "libraries": [
                            "../../../tensorflow/bazel-bin/native_client/libstt.so.if.lib",
                            "../../../tensorflow/bazel-bin/native_client/libkenlm.so.if.lib",
                            "../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib",
                        ],
                    },
                    {
                        "libraries": [
                            "../../../tensorflow/bazel-bin/native_client/libstt.so",
                            "../../../tensorflow/bazel-bin/native_client/libkenlm.so",
                            "../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so",
                        ],
                    },
                ],
            ],
        },
        {
@@ -13,3 +13,84 @@ git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/do

Cherry-pick fix for MSVC:
curl -vsSL https://github.com/kpu/kenlm/commit/d70e28403f07e88b276c6bd9f162d2a428530f2e.patch | git am -p1 --directory=native_client/kenlm

Most of the KenLM code is licensed under the LGPL. There are exceptions that
have their own licenses, listed below. See comments in those files for more
details.

util/getopt.* is getopt for Windows
util/murmur_hash.cc
util/string_piece.hh and util/string_piece.cc
util/double-conversion/LICENSE covers util/double-conversion except the build files
util/file.cc contains a modified implementation of mkstemp under the LGPL
util/integer_to_string.* is BSD

For the rest:

KenLM is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published
by the Free Software Foundation, either version 2.1 of the License, or
(at your option) any later version.

KenLM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License 2.1
along with KenLM code. If not, see <http://www.gnu.org/licenses/lgpl-2.1.html>.


util/double-conversion:

Copyright 2006-2011, the V8 project authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
  notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
  copyright notice, this list of conditions and the following
  disclaimer in the documentation and/or other materials provided
  with the distribution.
* Neither the name of Google Inc. nor the names of its
  contributors may be used to endorse or promote products derived
  from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


util/integer_to_string.*:

Copyright (C) 2014 Milo Yip

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
@@ -23,7 +23,7 @@ extern const char *kModelNames[6];
 * If so, return true and set recognized to the type. This is the only API in
 * this header designed for use by decoder authors.
 */
bool RecognizeBinary(const char *file, ModelType &recognized);
KENLM_EXPORT bool RecognizeBinary(const char *file, ModelType &recognized);

struct FixedWidthParameters {
  unsigned char order;
@@ -10,13 +10,19 @@

/* Configuration for ngram model. Separate header to reduce pollution. */

#if defined _MSC_VER
#define KENLM_EXPORT __declspec(dllexport)
#else
#define KENLM_EXPORT __attribute__ ((visibility("default")))
#endif /* _MSC_VER */

namespace lm {

class EnumerateVocab;

namespace ngram {

struct Config {
struct KENLM_EXPORT Config {
  // EFFECTIVE FOR BOTH ARPA AND BINARY READS

  // (default true) print progress bar to messages
@@ -149,7 +149,7 @@ typedef ProbingModel Model;
/* Autorecognize the file type, load, and return the virtual base class. Don't
 * use the virtual base class if you can avoid it. Instead, use the above
 * classes as template arguments to your own virtual feature function.*/
base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
KENLM_EXPORT base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);

} // namespace ngram
} // namespace lm
@@ -10,9 +10,16 @@
#include <string>
#include <stdint.h>

#if defined _MSC_VER
#define KENLM_EXPORT __declspec(dllexport)
#else
#define KENLM_EXPORT __attribute__ ((visibility("default")))
#endif /* _MSC_VER */


namespace util {

class scoped_fd {
class KENLM_EXPORT scoped_fd {
 public:
  scoped_fd() : fd_(-1) {}

@@ -82,7 +89,7 @@ class EndOfFileException : public Exception {
class UnsupportedOSException : public Exception {};

// Open for read only.
int OpenReadOrThrow(const char *name);
KENLM_EXPORT int OpenReadOrThrow(const char *name);
// Create file if it doesn't exist, truncate if it does. Opened for write.
int CreateOrThrow(const char *name);

@@ -110,7 +117,7 @@ bool OutputPathIsStdout(StringPiece path);

// Return value for SizeFile when it can't size properly.
const uint64_t kBadSize = (uint64_t)-1;
uint64_t SizeFile(int fd);
KENLM_EXPORT uint64_t SizeFile(int fd);
uint64_t SizeOrThrow(int fd);

void ResizeOrThrow(int fd, uint64_t to);
@@ -9,7 +9,7 @@ bindings-clean:
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings-build: ds-swig
	pip3 install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==45.0.0
	pip3 install --quiet $(PYTHON_PACKAGES) wheel setuptools
	DISTUTILS_USE_SDK=1 PATH=$(TOOLCHAIN_DIR):$(DS_SWIG_BIN_PATH):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python3 ./setup.py build_ext $(PYTHON_PLATFORM_NAME)

MANIFEST.in: bindings-build
@@ -14,13 +14,7 @@
#include "modelstate.h"

#include "workspace_status.h"

#ifndef USE_TFLITE
#include "tfmodelstate.h"
#else
#include "tflitemodelstate.h"
#endif // USE_TFLITE

#include "ctcdecode/ctc_beam_search_decoder.h"

#ifdef __ANDROID__
@@ -300,13 +294,7 @@ STT_CreateModel(const char* aModelPath,
    return STT_ERR_NO_MODEL;
  }

  std::unique_ptr<ModelState> model(
#ifndef USE_TFLITE
    new TFModelState()
#else
    new TFLiteModelState()
#endif
  );
  std::unique_ptr<ModelState> model(new TFLiteModelState());

  if (!model) {
    std::cerr << "Could not allocate model state." << std::endl;
@@ -1,263 +0,0 @@
#include "tfmodelstate.h"

#include "workspace_status.h"

using namespace tensorflow;
using std::vector;

TFModelState::TFModelState()
  : ModelState()
  , mmap_env_(nullptr)
  , session_(nullptr)
{
}

TFModelState::~TFModelState()
{
  if (session_) {
    Status status = session_->Close();
    if (!status.ok()) {
      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
    }
  }
}

int
TFModelState::init(const char* model_path)
{
  int err = ModelState::init(model_path);
  if (err != STT_ERR_OK) {
    return err;
  }

  Status status;
  SessionOptions options;

  mmap_env_.reset(new MemmappedEnv(Env::Default()));

  bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
  if (!is_mmap) {
    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
  } else {
    status = mmap_env_->InitializeFromFile(model_path);
    if (!status.ok()) {
      std::cerr << status << std::endl;
      return STT_ERR_FAIL_INIT_MMAP;
    }

    options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(::OptimizerOptions::L0);
    options.env = mmap_env_.get();
  }

  Session* session;
  status = NewSession(options, &session);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return STT_ERR_FAIL_INIT_SESS;
  }
  session_.reset(session);

  if (is_mmap) {
    status = ReadBinaryProto(mmap_env_.get(),
                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
                             &graph_def_);
  } else {
    status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
  }
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return STT_ERR_FAIL_READ_PROTOBUF;
  }

  status = session_->Create(graph_def_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return STT_ERR_FAIL_CREATE_SESS;
  }

  std::vector<tensorflow::Tensor> version_output;
  status = session_->Run({}, {
    "metadata_version"
  }, {}, &version_output);
  if (!status.ok()) {
    std::cerr << "Unable to fetch graph version: " << status << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  int graph_version = version_output[0].scalar<int>()();
  if (graph_version < ds_graph_version()) {
    std::cerr << "Specified model file version (" << graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << ds_graph_version() << "). See "
              << "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
              << "for more information" << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  std::vector<tensorflow::Tensor> metadata_outputs;
  status = session_->Run({}, {
    "metadata_sample_rate",
    "metadata_feature_win_len",
    "metadata_feature_win_step",
    "metadata_beam_width",
    "metadata_alphabet",
  }, {}, &metadata_outputs);
  if (!status.ok()) {
    std::cout << "Unable to fetch metadata: " << status << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  sample_rate_ = metadata_outputs[0].scalar<int>()();
  int win_len_ms = metadata_outputs[1].scalar<int>()();
  int win_step_ms = metadata_outputs[2].scalar<int>()();
  audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
  audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
  int beam_width = metadata_outputs[3].scalar<int>()();
  beam_width_ = (unsigned int)(beam_width);

  string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
  err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
  if (err != 0) {
    return STT_ERR_INVALID_ALPHABET;
  }

  assert(sample_rate_ > 0);
  assert(audio_win_len_ > 0);
  assert(audio_win_step_ > 0);
  assert(beam_width_ > 0);
  assert(alphabet_.GetSize() > 0);

  for (int i = 0; i < graph_def_.node_size(); ++i) {
    NodeDef node = graph_def_.node(i);
    if (node.name() == "input_node") {
      const auto& shape = node.attr().at("shape").shape();
      n_steps_ = shape.dim(1).size();
      n_context_ = (shape.dim(2).size()-1)/2;
      n_features_ = shape.dim(3).size();
      mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
    } else if (node.name() == "previous_state_c") {
      const auto& shape = node.attr().at("shape").shape();
      state_size_ = shape.dim(1).size();
    } else if (node.name() == "logits_shape") {
      Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
        continue;
      }

      int final_dim_size = logits_shape.vec<int>()(2) - 1;
      if (final_dim_size != alphabet_.GetSize()) {
        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
                  << "has size " << alphabet_.GetSize()
                  << ", but model has " << final_dim_size
                  << " classes in its output. Make sure you're passing an alphabet "
                  << "file with the same size as the one used for training."
                  << std::endl;
        return STT_ERR_INVALID_ALPHABET;
      }
    }
  }

  if (n_context_ == -1 || n_features_ == -1) {
    std::cerr << "Error: Could not infer input shape from model file. "
              << "Make sure input_node is a 4D tensor with shape "
              << "[batch_size=1, time, window_size, n_features]."
              << std::endl;
    return STT_ERR_INVALID_SHAPE;
  }

  return STT_ERR_OK;
}

Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
  Tensor ret(DT_FLOAT, shape);
  auto ret_mapped = ret.flat<float>();
  int i;
  for (i = 0; i < vec.size(); ++i) {
    ret_mapped(i) = vec[i];
  }
  for (; i < shape.num_elements(); ++i) {
    ret_mapped(i) = 0.f;
  }
  return ret;
}

void
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
  auto tensor_mapped = tensor.flat<float>();
  if (num_elements == -1) {
    num_elements = tensor.shape().num_elements();
  }
  for (int i = 0; i < num_elements; ++i) {
    vec.push_back(tensor_mapped(i));
  }
}

void
TFModelState::infer(const std::vector<float>& mfcc,
                    unsigned int n_frames,
                    const std::vector<float>& previous_state_c,
                    const std::vector<float>& previous_state_h,
                    vector<float>& logits_output,
                    vector<float>& state_c_output,
                    vector<float>& state_h_output)
{
  const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank

  Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
  Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
  Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));

  Tensor input_lengths(DT_INT32, TensorShape({1}));
  input_lengths.scalar<int>()() = n_frames;

  vector<Tensor> outputs;
  Status status = session_->Run(
    {
      {"input_node", input},
      {"input_lengths", input_lengths},
      {"previous_state_c", previous_state_c_t},
      {"previous_state_h", previous_state_h_t}
    },
    {"logits", "new_state_c", "new_state_h"},
    {},
    &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);

  state_c_output.clear();
  state_c_output.reserve(state_size_);
  copy_tensor_to_vector(outputs[1], state_c_output);

  state_h_output.clear();
  state_h_output.reserve(state_size_);
  copy_tensor_to_vector(outputs[2], state_h_output);
}

void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));

  vector<Tensor> outputs;
  Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  const int n_windows = 1;
  assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
  copy_tensor_to_vector(outputs[0], mfcc_output);
}
@@ -1,35 +0,0 @@
#ifndef TFMODELSTATE_H
#define TFMODELSTATE_H

#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"

#include "modelstate.h"

struct TFModelState : public ModelState
{
  std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
  std::unique_ptr<tensorflow::Session> session_;
  tensorflow::GraphDef graph_def_;

  TFModelState();
  virtual ~TFModelState();

  virtual int init(const char* model_path) override;

  virtual void infer(const std::vector<float>& mfcc,
                     unsigned int n_frames,
                     const std::vector<float>& previous_state_c,
                     const std::vector<float>& previous_state_h,
                     std::vector<float>& logits_output,
                     std::vector<float>& state_c_output,
                     std::vector<float>& state_h_output) override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;
};

#endif // TFMODELSTATE_H
4 notebooks/README.md Normal file
@@ -0,0 +1,4 @@
# Python Notebooks for 🐸 STT

1. Train a new Speech-to-Text model from scratch [Open in Colab](https://colab.research.google.com/github/coqui-ai/STT/blob/main/notebooks/train-your-first-coqui-STT-model.ipynb)
2. Transfer learning (English --> Russian) [Open in Colab](https://colab.research.google.com/github/coqui-ai/STT/blob/main/notebooks/easy-transfer-learning.ipynb)
281 notebooks/easy-transfer-learning.ipynb Normal file
@@ -0,0 +1,281 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45ea3ef5",
   "metadata": {},
   "source": [
    "# Easy transfer learning with 🐸 STT ⚡\n",
    "\n",
    "You want to train a Coqui (🐸) STT model, but you don't have a lot of data. What do you do?\n",
    "\n",
    "The answer 💡: Grab a pre-trained model and fine-tune it to your data. This is called `\"Transfer Learning\"` ⚡\n",
    "\n",
    "🐸 STT comes with transfer learning support out of the box.\n",
    "\n",
    "You can even take a pre-trained model and fine-tune it to _any new language_, even if the alphabets are completely different. Likewise, you can fine-tune a model to your own data and improve performance if the language is the same.\n",
    "\n",
    "In this notebook, we will:\n",
    "\n",
    "1. Download a pre-trained English STT model.\n",
    "2. Download data for the Russian language.\n",
    "3. Fine-tune the English model to the Russian language.\n",
    "4. Test the new Russian model and display its performance.\n",
    "\n",
    "So, let's jump right in!\n",
    "\n",
    "*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2aec77",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Install Coqui STT if you need to\n",
    "# !git clone --depth 1 https://github.com/coqui-ai/STT.git\n",
    "# !cd STT; pip install -U pip wheel setuptools; pip install ."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c07a273",
   "metadata": {},
   "source": [
    "## ✅ Download pre-trained English model\n",
    "\n",
    "We're going to download a very small (but very accurate) pre-trained STT model for English. This model was trained to only transcribe the English words \"yes\" and \"no\", but with transfer learning we can train a new model which could transcribe any words in any language. In this notebook, we will turn this \"constrained vocabulary\" English model into an \"open vocabulary\" Russian model.\n",
    "\n",
    "Coqui STT models are typically stored as checkpoints (for training) and protobufs (for deployment). For transfer learning, we want the **model checkpoints**.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608d203f",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Download pre-trained model\n",
    "import os\n",
    "import tarfile\n",
    "from coqui_stt_training.util.downloader import maybe_download\n",
    "\n",
    "def download_pretrained_model():\n",
    "    model_dir=\"english/\"\n",
    "    if not os.path.exists(\"english/coqui-yesno-checkpoints\"):\n",
    "        maybe_download(\"model.tar.gz\", model_dir, \"https://github.com/coqui-ai/STT-models/releases/download/english%2Fcoqui%2Fyesno-v0.0.1/coqui-yesno-checkpoints.tar.gz\")\n",
    "        print('\\nNo extracted pre-trained model found. Extracting now...')\n",
    "        tar = tarfile.open(\"english/model.tar.gz\")\n",
    "        tar.extractall(\"english/\")\n",
    "        tar.close()\n",
    "    else:\n",
    "        print('Found \"english/coqui-yesno-checkpoints\" - not extracting.')\n",
    "\n",
    "# Download + extract pre-trained English model\n",
    "download_pretrained_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed9dd7ab",
   "metadata": {},
   "source": [
    "## ✅ Download data for Russian\n",
    "\n",
    "**First things first**: we need some data.\n",
    "\n",
    "We're training a Speech-to-Text model, so we need some _speech_ and we need some _text_. Specifically, we want _transcribed speech_. Let's download a Russian audio file and its transcript, pre-formatted for 🐸 STT.\n",
    "\n",
    "**Second things second**: we want a Russian alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet. Let's download a Russian alphabet from Coqui and use that.\n",
    "\n",
    "*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5105ea7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "### Download sample data\n",
    "from coqui_stt_training.util.downloader import maybe_download\n",
    "\n",
    "def download_sample_data():\n",
    "    data_dir=\"russian/\"\n",
    "    maybe_download(\"ru.wav\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/ru.wav\")\n",
    "    maybe_download(\"ru.csv\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/ru.csv\")\n",
    "    maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/smoke_test/russian_sample_data/alphabet.ru\")\n",
    "\n",
    "# Download sample Russian data\n",
    "download_sample_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b46b7227",
   "metadata": {},
   "source": [
    "## ✅ Configure the training run\n",
    "\n",
    "Coqui STT comes with a long list of hyperparameters you can tweak. We've set default values, but you can use `initialize_globals_from_args()` to set your own.\n",
    "\n",
    "You must **always** configure the paths to your data, and you must **always** configure your alphabet. For transfer learning, it's good practice to define different `load_checkpoint_dir` and `save_checkpoint_dir` paths so that you keep your new model (Russian STT) separate from the old one (English STT). The parameter `drop_source_layers` allows you to remove layers from the original (aka \"source\") model, and re-initialize them from scratch. If you are fine-tuning to a new alphabet you will have to use _at least_ `drop_source_layers=1` to remove the output layer and add a new output layer which matches your new alphabet.\n",
    "\n",
    "We are fine-tuning a pre-existing model, so `n_hidden` should be the same as the original English model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cff3c5a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import initialize_globals_from_args\n",
    "\n",
    "initialize_globals_from_args(\n",
    "    n_hidden=64,\n",
    "    load_checkpoint_dir=\"english/coqui-yesno-checkpoints\",\n",
    "    save_checkpoint_dir=\"russian/checkpoints\",\n",
    "    drop_source_layers=1,\n",
    "    alphabet_config_path=\"russian/alphabet.txt\",\n",
    "    train_files=[\"russian/ru.csv\"],\n",
    "    dev_files=[\"russian/ru.csv\"],\n",
    "    epochs=200,\n",
    "    load_cudnn=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "419828c1",
   "metadata": {},
   "source": [
    "### View all Config settings (*Optional*)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac6ea3d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import Config\n",
    "\n",
    "print(Config.to_json())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8e700d1",
   "metadata": {},
   "source": [
    "## ✅ Train a new Russian model\n",
    "\n",
    "Let's kick off a training run 🚀🚀🚀 (using the configuration you set above).\n",
    "\n",
    "This notebook should work on either a GPU or a CPU. However, in case you're running this on _multiple_ GPUs we want to only use one, because the sample dataset (one audio file) is too small to split across multiple GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aab2195",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from coqui_stt_training.train import train\n",
    "\n",
    "# use at most one GPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c87ba61",
   "metadata": {},
   "source": [
    "## ✅ Configure the testing run\n",
    "\n",
    "Let's add the path to our testing data and update `load_checkpoint_dir` to our new model checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be7beb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import Config\n",
    "\n",
    "Config.test_files=[\"russian/ru.csv\"]\n",
    "Config.load_checkpoint_dir=\"russian/checkpoints\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6a5c971",
   "metadata": {},
   "source": [
    "## ✅ Test the new Russian model\n",
    "\n",
    "We made it! 🙌\n",
    "\n",
    "Let's kick off the testing run, which displays performance metrics.\n",
    "\n",
    "We're committing the cardinal sin of ML 😈 (aka - testing on our training data) so you don't want to deploy this model into production. In this notebook we're focusing on the workflow itself, so it's forgivable 😇\n",
    "\n",
    "You can see from the test output that our tiny model has overfit to the data, and basically memorized this one sentence.\n",
    "\n",
    "When you start training your own models, make sure your testing data doesn't include your training data 😅"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6222dc69",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.train import test\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
260 notebooks/train-your-first-coqui-STT-model.ipynb Normal file
@@ -0,0 +1,260 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f79d99ef",
   "metadata": {},
   "source": [
    "# Train your first 🐸 STT model 💫\n",
    "\n",
    "👋 Hello and welcome to Coqui (🐸) STT\n",
    "\n",
    "The goal of this notebook is to show you a **typical workflow** for **training** and **testing** an STT model with 🐸.\n",
    "\n",
    "Let's train a very small model on a very small amount of data so we can iterate quickly.\n",
    "\n",
    "In this notebook, we will:\n",
    "\n",
    "1. Download data and format it for 🐸 STT.\n",
    "2. Configure the training and testing runs.\n",
    "3. Train a new model.\n",
    "4. Test the model and display its performance.\n",
    "\n",
    "So, let's jump right in!\n",
    "\n",
    "*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2aec78",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Install Coqui STT if you need to\n",
    "# !git clone --depth 1 https://github.com/coqui-ai/STT.git\n",
    "# !cd STT; pip install -U pip wheel setuptools; pip install ."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be5fe49c",
   "metadata": {},
   "source": [
    "## ✅ Download & format sample data for English\n",
    "\n",
    "**First things first**: we need some data.\n",
    "\n",
    "We're training a Speech-to-Text model, so we need some _speech_ and we need some _text_. Specifically, we want _transcribed speech_. Let's download an English audio file and its transcript and then format them for 🐸 STT.\n",
    "\n",
    "🐸 STT expects to find information about your data in a CSV file, where each line contains:\n",
    "\n",
    "1. the **path** to an audio file\n",
    "2. the **size** of that audio file\n",
    "3. the **transcript** of that audio file.\n",
    "\n",
    "Formatting the audio and transcript isn't too difficult in this case. We define a custom data importer called `download_sample_data()` which does all the work. If you have a custom dataset, you will probably want to write a custom data importer.\n",
    "\n",
    "**Second things second**: we want an alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet, and you should specify this alphabet before training. Let's download an English alphabet from Coqui and use that.\n",
    "\n",
    "*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53945462",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "### Download sample data\n",
    "import os\n",
    "import pandas\n",
    "from coqui_stt_training.util.downloader import maybe_download\n",
    "\n",
    "def download_sample_data():\n",
    "    data_dir=\"english/\"\n",
    "    # Download data + alphabet\n",
    "    audio_file = maybe_download(\"LDC93S1.wav\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav\")\n",
    "    transcript_file = maybe_download(\"LDC93S1.txt\", data_dir, \"https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.txt\")\n",
    "    alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
    "    # Format data\n",
    "    with open(transcript_file, \"r\") as fin:\n",
    "        transcript = \" \".join(fin.read().strip().lower().split(\" \")[2:]).replace(\".\", \"\")\n",
    "    df = pandas.DataFrame(data=[(os.path.abspath(audio_file), os.path.getsize(audio_file), transcript)],\n",
    "                          columns=[\"wav_filename\", \"wav_filesize\", \"transcript\"])\n",
    "    # Save formatted CSV\n",
    "    df.to_csv(os.path.join(data_dir, \"ldc93s1.csv\"), index=False)\n",
    "\n",
    "# Download and format data\n",
    "download_sample_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96e8b708",
   "metadata": {},
   "source": [
    "### Take a look at the data (*Optional*)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa2aec77",
   "metadata": {},
   "outputs": [],
   "source": [
    "csv_file = open(\"english/ldc93s1.csv\", \"r\")\n",
    "print(csv_file.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c046277",
   "metadata": {},
   "outputs": [],
   "source": [
    "alphabet_file = open(\"english/alphabet.txt\", \"r\")\n",
    "print(alphabet_file.read())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9dfac21",
   "metadata": {},
   "source": [
    "## ✅ Configure & set hyperparameters\n",
    "\n",
    "Coqui STT comes with a long list of hyperparameters you can tweak. We've set default values, but you will often want to set your own. You can use `initialize_globals_from_args()` to do this.\n",
    "\n",
    "You must **always** configure the paths to your data, and you must **always** configure your alphabet. Additionally, here we show how you can specify the size of hidden layers (`n_hidden`), the number of epochs to train for (`epochs`), and to initialize a new model from scratch (`load_train=\"init\"`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d264fdec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import initialize_globals_from_args\n",
    "\n",
    "initialize_globals_from_args(\n",
    "    alphabet_config_path=\"english/alphabet.txt\",\n",
    "    train_files=[\"english/ldc93s1.csv\"],\n",
    "    dev_files=[\"english/ldc93s1.csv\"],\n",
    "    test_files=[\"english/ldc93s1.csv\"],\n",
    "    load_train=\"init\",\n",
    "    n_hidden=100,\n",
    "    epochs=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "799c1425",
   "metadata": {},
   "source": [
    "### View all Config settings (*Optional*)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03b33d2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from coqui_stt_training.util.config import Config\n",
    "\n",
    "# Take a peek at the entire Config\n",
    "print(Config.to_json())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae82fd75",
   "metadata": {},
   "source": [
    "## ✅ Train a new model\n",
    "\n",
    "Let's kick off a training run 🚀🚀🚀 (using the configuration you set above).\n",
    "\n",
    "This notebook should work on either a GPU or a CPU. However, in case you're running this on _multiple_ GPUs we want to only use one, because the sample dataset (one audio file) is too small to split across multiple GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "550a504e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from coqui_stt_training.train import train\n",
    "\n",
    "# use at most one GPU\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
"train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9f6dc959",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## ✅ Test the model\n",
|
||||
"\n",
|
||||
"We made it! 🙌\n",
|
||||
"\n",
|
||||
"Let's kick off the testing run, which displays performance metrics.\n",
|
||||
"\n",
|
||||
"We're committing the cardinal sin of ML 😈 (aka - testing on our training data) so you don't want to deploy this model into production. In this notebook we're focusing on the workflow itself, so it's forgivable 😇\n",
|
||||
"\n",
|
||||
"You can see from the test output that our tiny model has overfit to the data, and basically memorized this one sentence.\n",
|
||||
"\n",
|
||||
"When you start training your own models, make sure your testing data doesn't include your training data 😅"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd42bc7a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from coqui_stt_training.train import test\n",
|
||||
"\n",
|
||||
"test()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
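As an aside to the notebook's footnote on large character sets: below is a minimal sketch of what the same configuration call might look like with byte-level outputs instead of an alphabet file. It assumes `bytes_output_mode` is accepted as a keyword by `initialize_globals_from_args()` (mirroring the flag described in the markdown cell above), and the dataset paths are hypothetical; treat it as illustrative, not as the committed API.

from coqui_stt_training.util.config import initialize_globals_from_args

# Hypothetical byte-level setup: no alphabet_config_path is passed, because the
# output layer now maps to individual UTF-8 bytes rather than characters.
initialize_globals_from_args(
    bytes_output_mode=True,             # assumption: forwarded to Config.bytes_output_mode
    train_files=["chinese/train.csv"],  # hypothetical dataset paths
    dev_files=["chinese/dev.csv"],
    test_files=["chinese/test.csv"],
    load_train="init",
    n_hidden=100,
    epochs=200,
)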
13
setup.py
@ -8,7 +8,7 @@ from setuptools import find_packages, setup


def main():
    version_file = Path(__file__).parent / "VERSION"
    version_file = Path(__file__).parent / "training" / "coqui_stt_training" / "VERSION"
    with open(str(version_file)) as fin:
        version = fin.read().strip()

@ -18,6 +18,7 @@ def main():
        "coqpit",
        "numpy",
        "optuna",
        "numba <= 0.53.1",
        "opuslib == 2.0.0",
        "pandas",
        "progressbar2",
@ -29,6 +30,7 @@ def main():
        "six",
        "sox",
        "soundfile",
        "tqdm",
    ]

    decoder_pypi_dep = ["coqui_stt_ctcdecoder == {}".format(version)]
@ -66,14 +68,7 @@ def main():
        packages=find_packages(where="training"),
        python_requires=">=3.5, <4",
        install_requires=install_requires,
        # If there are data files included in your packages that need to be
        # installed, specify them here.
        package_data={
            "coqui_stt_training": [
                "VERSION",
                "GRAPH_VERSION",
            ],
        },
        include_package_data=True,
    )


@ -1 +1 @@
0.10.0-alpha.9
0.10.0-alpha.14
403
training/coqui_stt_training/deepspeech_model.py
Normal file
@ -0,0 +1,403 @@

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys

LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
    sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

tfv1.logging.set_verbosity(
    {
        "0": tfv1.logging.DEBUG,
        "1": tfv1.logging.INFO,
        "2": tfv1.logging.WARN,
        "3": tfv1.logging.ERROR,
    }.get(DESIRED_LOG_LEVEL)
)

from .util.config import Config
from .util.feeding import audio_to_features


def variable_on_cpu(name, shape, initializer):
    r"""
    Next we concern ourselves with graph creation.
    However, before we do so we must introduce a utility function ``variable_on_cpu()``
    used to create a variable in CPU memory.
    """
    # Use the /cpu:0 device for scoped operations
    with tf.device(Config.cpu_device):
        # Create or get apropos variable
        var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
    return var


def create_overlapping_windows(batch_x):
    batch_size = tf.shape(input=batch_x)[0]
    window_width = 2 * Config.n_context + 1
    num_channels = Config.n_input

    # Create a constant convolution filter using an identity matrix, so that the
    # convolution returns patches of the input tensor as is, and we can create
    # overlapping windows over the MFCCs.
    eye_filter = tf.constant(
        np.eye(window_width * num_channels).reshape(
            window_width, num_channels, window_width * num_channels
        ),
        tf.float32,
    )  # pylint: disable=bad-continuation

    # Create overlapping windows
    batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding="SAME")

    # Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
    batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])

    return batch_x


def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False):
    with tfv1.variable_scope(name):
        bias = variable_on_cpu("bias", [units], tfv1.zeros_initializer())
        weights = variable_on_cpu(
            "weights",
            [x.shape[-1], units],
            tfv1.keras.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"
            ),
        )

    output = tf.nn.bias_add(tf.matmul(x, weights), bias)

    if relu:
        output = tf.minimum(tf.nn.relu(output), Config.relu_clip)

    if layer_norm:
        with tfv1.variable_scope(name):
            output = tf.contrib.layers.layer_norm(output)

    if dropout_rate is not None:
        output = tf.nn.dropout(output, rate=dropout_rate)

    return output


def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
    with tfv1.variable_scope("cudnn_lstm/rnn/multi_rnn_cell/cell_0"):
        fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(
            Config.n_cell_dim,
            forget_bias=0,
            reuse=reuse,
            name="cudnn_compatible_lstm_cell",
        )

        output, output_state = fw_cell(
            inputs=x,
            dtype=tf.float32,
            sequence_length=seq_length,
            initial_state=previous_state,
        )

    return output, output_state


def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
    assert (
        previous_state is None
    )  # 'Passing previous state not supported with CuDNN backend'

    # Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
    # the object it creates the variables, and then you just call it several times
    # to enable variable re-use. Because all of our code is structured in an old
    # school TensorFlow structure where you can just call tf.get_variable again with
    # reuse=True to reuse variables, we can't easily make use of the object oriented
    # way CudnnLSTM is implemented, so we save a singleton instance in the function,
    # emulating a static function variable.
    if not rnn_impl_cudnn_rnn.cell:
        # Forward direction cell:
        fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(
            num_layers=1,
            num_units=Config.n_cell_dim,
            input_mode="linear_input",
            direction="unidirectional",
            dtype=tf.float32,
        )
        rnn_impl_cudnn_rnn.cell = fw_cell

    output, output_state = rnn_impl_cudnn_rnn.cell(
        inputs=x, sequence_lengths=seq_length
    )

    return output, output_state


rnn_impl_cudnn_rnn.cell = None


def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
    with tfv1.variable_scope("cudnn_lstm/rnn/multi_rnn_cell"):
        # Forward direction cell:
        fw_cell = tfv1.nn.rnn_cell.LSTMCell(
            Config.n_cell_dim,
            forget_bias=0,
            reuse=reuse,
            name="cudnn_compatible_lstm_cell",
        )

        # Split rank N tensor into list of rank N-1 tensors
        x = [x[l] for l in range(x.shape[0])]

        output, output_state = tfv1.nn.static_rnn(
            cell=fw_cell,
            inputs=x,
            sequence_length=seq_length,
            initial_state=previous_state,
            dtype=tf.float32,
            scope="cell_0",
        )

        output = tf.concat(output, 0)

    return output, output_state


def create_model(
    batch_x,
    seq_length,
    dropout,
    reuse=False,
    batch_size=None,
    previous_state=None,
    overlap=True,
    rnn_impl=rnn_impl_lstmblockfusedcell,
):
    layers = {}

    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
    if not batch_size:
        batch_size = tf.shape(input=batch_x)[0]

    # Create overlapping feature windows if needed
    if overlap:
        batch_x = create_overlapping_windows(batch_x)

    # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
    # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.

    # Permute n_steps and batch_size
    batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
    # Reshape to prepare input for first layer
    batch_x = tf.reshape(
        batch_x, [-1, Config.n_input + 2 * Config.n_input * Config.n_context]
    )  # (n_steps*batch_size, n_input + 2*n_input*n_context)
    layers["input_reshaped"] = batch_x

    # The next three blocks will pass `batch_x` through three hidden layers with
    # clipped RELU activation and dropout.
    layers["layer_1"] = layer_1 = dense(
        "layer_1",
        batch_x,
        Config.n_hidden_1,
        dropout_rate=dropout[0],
        layer_norm=Config.layer_norm,
    )
    layers["layer_2"] = layer_2 = dense(
        "layer_2",
        layer_1,
        Config.n_hidden_2,
        dropout_rate=dropout[1],
        layer_norm=Config.layer_norm,
    )
    layers["layer_3"] = layer_3 = dense(
        "layer_3",
        layer_2,
        Config.n_hidden_3,
        dropout_rate=dropout[2],
        layer_norm=Config.layer_norm,
    )

    # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
    # as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
    layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])

    # Run through parametrized RNN implementation, as we use different RNNs
    # for training and inference
    output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)

    # Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
    # to a tensor of shape [n_steps*batch_size, n_cell_dim]
    output = tf.reshape(output, [-1, Config.n_cell_dim])
    layers["rnn_output"] = output
    layers["rnn_output_state"] = output_state

    # Now we feed `output` to the fifth hidden layer with clipped RELU activation
    layers["layer_5"] = layer_5 = dense(
        "layer_5",
        output,
        Config.n_hidden_5,
        dropout_rate=dropout[5],
        layer_norm=Config.layer_norm,
    )

    # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
    layers["layer_6"] = layer_6 = dense(
        "layer_6", layer_5, Config.n_hidden_6, relu=False
    )

    # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
    # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
    # Note, that this differs from the input in that it is time-major.
    layer_6 = tf.reshape(
        layer_6, [-1, batch_size, Config.n_hidden_6], name="raw_logits"
    )
    layers["raw_logits"] = layer_6

    # Output shape: [n_steps, batch_size, n_hidden_6]
    return layer_6, layers


def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    batch_size = batch_size if batch_size > 0 else None

    # Create feature computation graph

    # native_client: this node's name and shape are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    input_samples = tfv1.placeholder(
        tf.float32, [Config.audio_window_samples], "input_samples"
    )
    samples = tf.expand_dims(input_samples, -1)
    mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
    # native_client: this node's name and shape are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    mfccs = tf.identity(mfccs, name="mfccs")

    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
    # This shape is read by the native_client in STT_CreateModel to know the
    # value of n_steps, n_context and n_input. Make sure you update the code
    # there if this shape is changed.
    #
    # native_client: this node's name and shape are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    input_tensor = tfv1.placeholder(
        tf.float32,
        [
            batch_size,
            n_steps if n_steps > 0 else None,
            2 * Config.n_context + 1,
            Config.n_input,
        ],
        name="input_node",
    )
    # native_client: this node's name and shape are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")

    if batch_size <= 0:
        # no state management since n_step is expected to be dynamic too (see below)
        previous_state = None
    else:
        # native_client: this node's name and shape are part of the API boundary
        # with the native client, if you change them you should sync changes with
        # the C++ code.
        previous_state_c = tfv1.placeholder(
            tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
        )
        # native_client: this node's name and shape are part of the API boundary
        # with the native client, if you change them you should sync changes with
        # the C++ code.
        previous_state_h = tfv1.placeholder(
            tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
        )

        previous_state = tf.nn.rnn_cell.LSTMStateTuple(
            previous_state_c, previous_state_h
        )

    # One rate per layer
    no_dropout = [None] * 6

    if tflite:
        rnn_impl = rnn_impl_static_rnn
    else:
        rnn_impl = rnn_impl_lstmblockfusedcell

    logits, layers = create_model(
        batch_x=input_tensor,
        batch_size=batch_size,
        seq_length=seq_length if not Config.export_tflite else None,
        dropout=no_dropout,
        previous_state=previous_state,
        overlap=False,
        rnn_impl=rnn_impl,
    )

    # TF Lite runtime will check that input dimensions are 1, 2 or 4
    # by default we get 3, the middle one being batch_size which is forced to
    # one on inference graph, so remove that dimension
    #
    # native_client: this node's name and shape are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    if tflite:
        logits = tf.squeeze(logits, [1])

    # Apply softmax for CTC decoder
    probs = tf.nn.softmax(logits, name="logits")

    if batch_size <= 0:
        if tflite:
            raise NotImplementedError(
                "dynamic batch_size does not support tflite nor streaming"
            )
        if n_steps > 0:
            raise NotImplementedError(
                "dynamic batch_size expect n_steps to be dynamic too"
            )
        return (
            {
                "input": input_tensor,
                "input_lengths": seq_length,
            },
            {
                "outputs": probs,
            },
            layers,
        )

    new_state_c, new_state_h = layers["rnn_output_state"]
    new_state_c = tf.identity(new_state_c, name="new_state_c")
    new_state_h = tf.identity(new_state_h, name="new_state_h")

    inputs = {
        "input": input_tensor,
        "previous_state_c": previous_state_c,
        "previous_state_h": previous_state_h,
        "input_samples": input_samples,
    }

    if not Config.export_tflite:
        inputs["input_lengths"] = seq_length

    outputs = {
        "outputs": probs,
        "new_state_c": new_state_c,
        "new_state_h": new_state_h,
        "mfccs": mfccs,
        # Expose internal layers for downstream applications
        "layer_3": layers["layer_3"],
        "layer_5": layers["layer_5"],
    }

    return inputs, outputs, layers
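A quick aside on the identity-filter trick used by create_overlapping_windows() above: convolving with a reshaped identity matrix copies each input frame, untouched, into every window position that covers it, yielding flattened sliding windows of frames. A minimal numpy sketch of that idea, with toy sizes rather than the real Config values:

import numpy as np

# Toy sizes; Config.n_context and Config.n_input are larger in practice.
n_context, n_input = 1, 2
window_width = 2 * n_context + 1

# Same construction as in create_overlapping_windows(): reshape an identity
# matrix into a (width, in_channels, out_channels) convolution filter.
eye_filter = np.eye(window_width * n_input).reshape(
    window_width, n_input, window_width * n_input
)

frames = np.arange(10, dtype=float).reshape(5, n_input)    # 5 frames of MFCC-like features
padded = np.pad(frames, ((n_context, n_context), (0, 0)))  # "SAME" zero padding

# Emulate tf.nn.conv1d with stride 1: out[t] = sum_k padded[t + k] @ eye_filter[k]
out = np.zeros((len(frames), window_width * n_input))
for t in range(len(frames)):
    for k in range(window_width):
        out[t] += padded[t + k] @ eye_filter[k]

# Each output row is the flattened window of frames centered on t, unchanged:
for t in range(len(frames)):
    assert np.allclose(out[t].reshape(window_width, n_input), padded[t : t + window_width])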
34
training/coqui_stt_training/evaluate.py
Executable file → Normal file
@ -13,12 +13,13 @@ from six.moves import zip

import tensorflow as tf

from .deepspeech_model import create_model
from .util.augmentations import NormalizeSampleRate
from .util.checkpoints import load_graph_for_evaluation
from .util.config import (
    Config,
    create_progressbar,
    initialize_globals,
    initialize_globals_from_cli,
    log_error,
    log_progress,
)
@ -26,8 +27,6 @@ from .util.evaluate_tools import calculate_and_print_report, save_samples_json
from .util.feeding import create_dataset
from .util.helpers import check_ctcdecoder_version

check_ctcdecoder_version()


def sparse_tensor_value_to_texts(value, alphabet):
    r"""
@ -168,25 +167,26 @@ def evaluate(test_csvs, create_model):
    return samples


def main():
    initialize_globals()

    if not Config.test_files:
        log_error(
            "You need to specify what files to use for evaluation via "
            "the --test_files flag."
        )
        sys.exit(1)

    from .train import (  # pylint: disable=cyclic-import,import-outside-toplevel
        create_model,
    )
def test():
    tfv1.reset_default_graph()

    samples = evaluate(Config.test_files, create_model)

    if Config.test_output_file:
        save_samples_json(samples, Config.test_output_file)


def main():
    initialize_globals_from_cli()
    check_ctcdecoder_version()

    if not Config.test_files:
        raise RuntimeError(
            "You need to specify what files to use for evaluation via "
            "the --test_files flag."
        )

    test()


if __name__ == "__main__":
    main()
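The refactor above is what lets the notebook earlier in this diff call `test()` programmatically instead of invoking the evaluation script. A minimal sketch of that flow, assuming a trained checkpoint already exists under the default checkpoint directory and the sample data from the notebook is in place:

from coqui_stt_training.util.config import initialize_globals_from_args
from coqui_stt_training.evaluate import test

# Point the global Config at the data and alphabet, then evaluate the test set.
initialize_globals_from_args(
    alphabet_config_path="english/alphabet.txt",
    test_files=["english/ldc93s1.csv"],
)
test()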
216
training/coqui_stt_training/export.py
Normal file
@ -0,0 +1,216 @@

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys

LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
    sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL

import tensorflow as tf
import tensorflow.compat.v1 as tfv1
import shutil

from .deepspeech_model import create_inference_graph
from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error, log_info
from .util.io import (
    open_remote,
    rmtree_remote,
    listdir_remote,
    is_remote_path,
    isdir_remote,
)


def file_relative_read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()


def export():
    r"""
    Restores the trained variables into a simpler graph that will be exported for serving.
    """
    log_info("Exporting the model...")

    tfv1.reset_default_graph()

    inputs, outputs, _ = create_inference_graph(
        batch_size=Config.export_batch_size,
        n_steps=Config.n_steps,
        tflite=Config.export_tflite,
    )

    graph_version = int(file_relative_read("GRAPH_VERSION").strip())
    assert graph_version > 0

    # native_client: these nodes' names and shapes are part of the API boundary
    # with the native client, if you change them you should sync changes with
    # the C++ code.
    outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
    outputs["metadata_sample_rate"] = tf.constant(
        [Config.audio_sample_rate], name="metadata_sample_rate"
    )
    outputs["metadata_feature_win_len"] = tf.constant(
        [Config.feature_win_len], name="metadata_feature_win_len"
    )
    outputs["metadata_feature_win_step"] = tf.constant(
        [Config.feature_win_step], name="metadata_feature_win_step"
    )
    outputs["metadata_beam_width"] = tf.constant(
        [Config.export_beam_width], name="metadata_beam_width"
    )
    outputs["metadata_alphabet"] = tf.constant(
        [Config.alphabet.Serialize()], name="metadata_alphabet"
    )

    if Config.export_language:
        outputs["metadata_language"] = tf.constant(
            [Config.export_language.encode("utf-8")], name="metadata_language"
        )

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    output_names_tensors = [
        tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)
    ]
    output_names_ops = [
        op.name for op in outputs.values() if isinstance(op, tf.Operation)
    ]
    output_names = output_names_tensors + output_names_ops

    with tf.Session() as session:
        # Restore variables from checkpoint
        load_graph_for_evaluation(session)

        output_filename = Config.export_file_name + ".pb"
        if Config.remove_export:
            if isdir_remote(Config.export_dir):
                log_info("Removing old export")
                rmtree_remote(Config.export_dir)

        output_graph_path = os.path.join(Config.export_dir, output_filename)

        if not is_remote_path(Config.export_dir) and not os.path.isdir(
            Config.export_dir
        ):
            os.makedirs(Config.export_dir)

        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names,
        )

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph, dest_nodes=output_names
        )

        if not Config.export_tflite:
            with open_remote(output_graph_path, "wb") as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            output_tflite_path = os.path.join(
                Config.export_dir, output_filename.replace(".pb", ".tflite")
            )

            converter = tf.lite.TFLiteConverter(
                frozen_graph,
                input_tensors=inputs.values(),
                output_tensors=outputs.values(),
            )

            if Config.export_quantize:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]

            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open_remote(output_tflite_path, "wb") as fout:
                fout.write(tflite_model)

    log_info("Models exported at %s" % (Config.export_dir))

    metadata_fname = os.path.join(
        Config.export_dir,
        "{}_{}_{}.md".format(
            Config.export_author_id,
            Config.export_model_name,
            Config.export_model_version,
        ),
    )

    model_runtime = "tflite" if Config.export_tflite else "tensorflow"
    with open_remote(metadata_fname, "w") as f:
        f.write("---\n")
        f.write("author: {}\n".format(Config.export_author_id))
        f.write("model_name: {}\n".format(Config.export_model_name))
        f.write("model_version: {}\n".format(Config.export_model_version))
        f.write("contact_info: {}\n".format(Config.export_contact_info))
        f.write("license: {}\n".format(Config.export_license))
        f.write("language: {}\n".format(Config.export_language))
        f.write("runtime: {}\n".format(model_runtime))
        f.write("min_stt_version: {}\n".format(Config.export_min_stt_version))
        f.write("max_stt_version: {}\n".format(Config.export_max_stt_version))
        f.write(
            "acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n"
        )
        f.write(
            "scorer_url: <replace this with a publicly available URL of the scorer, if present>\n"
        )
        f.write("---\n")
        f.write("{}\n".format(Config.export_description))

    log_info(
        "Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.".format(
            metadata_fname
        )
    )


def package_zip():
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    export_dir = os.path.join(
        os.path.abspath(Config.export_dir), ""
    )  # Force ending '/'
    if is_remote_path(export_dir):
        log_error(
            "Cannot package remote path zip %s. Please do this manually." % export_dir
        )
        return

    zip_filename = os.path.dirname(export_dir)

    shutil.copy(Config.scorer_path, export_dir)

    archive = shutil.make_archive(zip_filename, "zip", export_dir)
    log_info("Exported packaged model {}".format(archive))


def main():
    initialize_globals_from_cli()

    if not Config.export_dir:
        raise RuntimeError(
            "Calling export script directly but no --export_dir specified"
        )

    if not Config.export_zip:
        # Export to folder
        export()
    else:
        if listdir_remote(Config.export_dir):
            raise RuntimeError(
                "Directory {} is not empty, please fix this.".format(Config.export_dir)
            )

        export()
        package_zip()


if __name__ == "__main__":
    main()
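For context on what export() produces in TFLite mode: a short sketch of inspecting the resulting .tflite artifact with the stock TFLite Python interpreter. The path is hypothetical (the default export filename and directory depend on Config), and in practice the native client consumes this file; whether allocate_tensors() succeeds depends on the custom AudioSpectrogram/Mfcc kernels being available in your TFLite build, so it is omitted here.

import tensorflow as tf

# Hypothetical path to the artifact written by export() with --export_tflite.
interpreter = tf.lite.Interpreter(model_path="exports/output_graph.tflite")

# The tensor names here should correspond to the native_client API boundary
# nodes created in create_inference_graph() (input_node, previous_state_c, ...).
for detail in interpreter.get_input_details():
    print("input:", detail["name"], detail["shape"])
for detail in interpreter.get_output_details():
    print("output:", detail["name"], detail["shape"])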
@ -14,12 +14,14 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import progressbar
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
import tensorflow as tf
|
||||
from coqui_stt_ctcdecoder import Scorer
|
||||
|
||||
tfv1.logging.set_verbosity(
|
||||
{
|
||||
@ -30,12 +32,15 @@ tfv1.logging.set_verbosity(
|
||||
}.get(DESIRED_LOG_LEVEL)
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder
|
||||
from six.moves import range, zip
|
||||
|
||||
from .evaluate import evaluate
|
||||
from . import evaluate
|
||||
from . import export
|
||||
from . import training_graph_inference
|
||||
from .deepspeech_model import (
|
||||
create_model,
|
||||
rnn_impl_lstmblockfusedcell,
|
||||
rnn_impl_cudnn_rnn,
|
||||
)
|
||||
from .util.augmentations import NormalizeSampleRate
|
||||
from .util.checkpoints import (
|
||||
load_graph_for_evaluation,
|
||||
@ -45,266 +50,21 @@ from .util.checkpoints import (
|
||||
from .util.config import (
|
||||
Config,
|
||||
create_progressbar,
|
||||
initialize_globals,
|
||||
initialize_globals_from_cli,
|
||||
log_debug,
|
||||
log_error,
|
||||
log_info,
|
||||
log_progress,
|
||||
log_warn,
|
||||
)
|
||||
from .util.evaluate_tools import save_samples_json
|
||||
from .util.feeding import audio_to_features, audiofile_to_features, create_dataset
|
||||
from .util.helpers import ExceptionBox, check_ctcdecoder_version
|
||||
from .util.feeding import create_dataset
|
||||
from .util.helpers import check_ctcdecoder_version
|
||||
from .util.io import (
|
||||
is_remote_path,
|
||||
isdir_remote,
|
||||
listdir_remote,
|
||||
open_remote,
|
||||
remove_remote,
|
||||
)
|
||||
|
||||
check_ctcdecoder_version()
|
||||
|
||||
# Graph Creation
|
||||
# ==============
|
||||
|
||||
|
||||
def variable_on_cpu(name, shape, initializer):
|
||||
r"""
|
||||
Next we concern ourselves with graph creation.
|
||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||
used to create a variable in CPU memory.
|
||||
"""
|
||||
# Use the /cpu:0 device for scoped operations
|
||||
with tf.device(Config.cpu_device):
|
||||
# Create or get apropos variable
|
||||
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||
return var
|
||||
|
||||
|
||||
def create_overlapping_windows(batch_x):
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
window_width = 2 * Config.n_context + 1
|
||||
num_channels = Config.n_input
|
||||
|
||||
# Create a constant convolution filter using an identity matrix, so that the
|
||||
# convolution returns patches of the input tensor as is, and we can create
|
||||
# overlapping windows over the MFCCs.
|
||||
eye_filter = tf.constant(
|
||||
np.eye(window_width * num_channels).reshape(
|
||||
window_width, num_channels, window_width * num_channels
|
||||
),
|
||||
tf.float32,
|
||||
) # pylint: disable=bad-continuation
|
||||
|
||||
# Create overlapping windows
|
||||
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding="SAME")
|
||||
|
||||
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||
|
||||
return batch_x
|
||||
|
||||
|
||||
def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False):
|
||||
with tfv1.variable_scope(name):
|
||||
bias = variable_on_cpu("bias", [units], tfv1.zeros_initializer())
|
||||
weights = variable_on_cpu(
|
||||
"weights",
|
||||
[x.shape[-1], units],
|
||||
tfv1.keras.initializers.VarianceScaling(
|
||||
scale=1.0, mode="fan_avg", distribution="uniform"
|
||||
),
|
||||
)
|
||||
|
||||
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||
|
||||
if relu:
|
||||
output = tf.minimum(tf.nn.relu(output), Config.relu_clip)
|
||||
|
||||
if layer_norm:
|
||||
with tfv1.variable_scope(name):
|
||||
output = tf.contrib.layers.layer_norm(output)
|
||||
|
||||
if dropout_rate is not None:
|
||||
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope("cudnn_lstm/rnn/multi_rnn_cell/cell_0"):
|
||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(
|
||||
Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name="cudnn_compatible_lstm_cell",
|
||||
)
|
||||
|
||||
output, output_state = fw_cell(
|
||||
inputs=x,
|
||||
dtype=tf.float32,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state,
|
||||
)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||
assert (
|
||||
previous_state is None
|
||||
) # 'Passing previous state not supported with CuDNN backend'
|
||||
|
||||
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||
# the object it creates the variables, and then you just call it several times
|
||||
# to enable variable re-use. Because all of our code is structure in an old
|
||||
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||
# emulating a static function variable.
|
||||
if not rnn_impl_cudnn_rnn.cell:
|
||||
# Forward direction cell:
|
||||
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(
|
||||
num_layers=1,
|
||||
num_units=Config.n_cell_dim,
|
||||
input_mode="linear_input",
|
||||
direction="unidirectional",
|
||||
dtype=tf.float32,
|
||||
)
|
||||
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||
|
||||
output, output_state = rnn_impl_cudnn_rnn.cell(
|
||||
inputs=x, sequence_lengths=seq_length
|
||||
)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
rnn_impl_cudnn_rnn.cell = None
|
||||
|
||||
|
||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||
with tfv1.variable_scope("cudnn_lstm/rnn/multi_rnn_cell"):
|
||||
# Forward direction cell:
|
||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(
|
||||
Config.n_cell_dim,
|
||||
forget_bias=0,
|
||||
reuse=reuse,
|
||||
name="cudnn_compatible_lstm_cell",
|
||||
)
|
||||
|
||||
# Split rank N tensor into list of rank N-1 tensors
|
||||
x = [x[l] for l in range(x.shape[0])]
|
||||
|
||||
output, output_state = tfv1.nn.static_rnn(
|
||||
cell=fw_cell,
|
||||
inputs=x,
|
||||
sequence_length=seq_length,
|
||||
initial_state=previous_state,
|
||||
dtype=tf.float32,
|
||||
scope="cell_0",
|
||||
)
|
||||
|
||||
output = tf.concat(output, 0)
|
||||
|
||||
return output, output_state
|
||||
|
||||
|
||||
def create_model(
|
||||
batch_x,
|
||||
seq_length,
|
||||
dropout,
|
||||
reuse=False,
|
||||
batch_size=None,
|
||||
previous_state=None,
|
||||
overlap=True,
|
||||
rnn_impl=rnn_impl_lstmblockfusedcell,
|
||||
):
|
||||
layers = {}
|
||||
|
||||
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||
if not batch_size:
|
||||
batch_size = tf.shape(input=batch_x)[0]
|
||||
|
||||
# Create overlapping feature windows if needed
|
||||
if overlap:
|
||||
batch_x = create_overlapping_windows(batch_x)
|
||||
|
||||
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
||||
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
||||
|
||||
# Permute n_steps and batch_size
|
||||
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
||||
# Reshape to prepare input for first layer
|
||||
batch_x = tf.reshape(
|
||||
batch_x, [-1, Config.n_input + 2 * Config.n_input * Config.n_context]
|
||||
) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
||||
layers["input_reshaped"] = batch_x
|
||||
|
||||
# The next three blocks will pass `batch_x` through three hidden layers with
|
||||
# clipped RELU activation and dropout.
|
||||
layers["layer_1"] = layer_1 = dense(
|
||||
"layer_1",
|
||||
batch_x,
|
||||
Config.n_hidden_1,
|
||||
dropout_rate=dropout[0],
|
||||
layer_norm=Config.layer_norm,
|
||||
)
|
||||
layers["layer_2"] = layer_2 = dense(
|
||||
"layer_2",
|
||||
layer_1,
|
||||
Config.n_hidden_2,
|
||||
dropout_rate=dropout[1],
|
||||
layer_norm=Config.layer_norm,
|
||||
)
|
||||
layers["layer_3"] = layer_3 = dense(
|
||||
"layer_3",
|
||||
layer_2,
|
||||
Config.n_hidden_3,
|
||||
dropout_rate=dropout[2],
|
||||
layer_norm=Config.layer_norm,
|
||||
)
|
||||
|
||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
||||
|
||||
# Run through parametrized RNN implementation, as we use different RNNs
|
||||
# for training and inference
|
||||
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
||||
|
||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
||||
layers["rnn_output"] = output
|
||||
layers["rnn_output_state"] = output_state
|
||||
|
||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
||||
layers["layer_5"] = layer_5 = dense(
|
||||
"layer_5",
|
||||
output,
|
||||
Config.n_hidden_5,
|
||||
dropout_rate=dropout[5],
|
||||
layer_norm=Config.layer_norm,
|
||||
)
|
||||
|
||||
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
||||
layers["layer_6"] = layer_6 = dense(
|
||||
"layer_6", layer_5, Config.n_hidden_6, relu=False
|
||||
)
|
||||
|
||||
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
||||
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
||||
# Note, that this differs from the input in that it is time-major.
|
||||
layer_6 = tf.reshape(
|
||||
layer_6, [-1, batch_size, Config.n_hidden_6], name="raw_logits"
|
||||
)
|
||||
layers["raw_logits"] = layer_6
|
||||
|
||||
# Output shape: [n_steps, batch_size, n_hidden_6]
|
||||
return layer_6, layers
|
||||
|
||||
|
||||
# Accuracy and Loss
|
||||
# =================
|
||||
@ -480,50 +240,42 @@ def average_gradients(tower_gradients):
|
||||
return average_grads
|
||||
|
||||
|
||||
# Logging
|
||||
# =======
|
||||
def early_training_checks():
|
||||
check_ctcdecoder_version()
|
||||
|
||||
# Check for proper scorer early
|
||||
if Config.scorer_path:
|
||||
scorer = Scorer(
|
||||
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
|
||||
)
|
||||
del scorer
|
||||
|
||||
if (
|
||||
Config.train_files
|
||||
and Config.test_files
|
||||
and Config.load_checkpoint_dir != Config.save_checkpoint_dir
|
||||
):
|
||||
log_warn(
|
||||
"WARNING: You specified different values for --load_checkpoint_dir "
|
||||
"and --save_checkpoint_dir, but you are running training and testing "
|
||||
"in a single invocation. The testing step will respect --load_checkpoint_dir, "
|
||||
"and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. "
|
||||
"Train and test in two separate invocations, specifying the correct "
|
||||
"--load_checkpoint_dir in both cases, or use the same location "
|
||||
"for loading and saving."
|
||||
)
|
||||
|
||||
|
||||
def log_variable(variable, gradient=None):
|
||||
r"""
|
||||
We introduce a function for logging a tensor variable's current state.
|
||||
It logs scalar values for the mean, standard deviation, minimum and maximum.
|
||||
Furthermore it logs a histogram of its state and (if given) of an optimization gradient.
|
||||
def create_training_datasets() -> (
|
||||
tf.data.Dataset,
|
||||
[tf.data.Dataset],
|
||||
[tf.data.Dataset],
|
||||
):
|
||||
"""Creates training datasets from input flags.
|
||||
|
||||
Returns a single training dataset and two lists of datasets for validation
|
||||
and metrics tracking.
|
||||
"""
|
||||
name = variable.name.replace(":", "_")
|
||||
mean = tf.reduce_mean(input_tensor=variable)
|
||||
tfv1.summary.scalar(name="%s/mean" % name, tensor=mean)
|
||||
tfv1.summary.scalar(
|
||||
name="%s/sttdev" % name,
|
||||
tensor=tf.sqrt(tf.reduce_mean(input_tensor=tf.square(variable - mean))),
|
||||
)
|
||||
tfv1.summary.scalar(
|
||||
name="%s/max" % name, tensor=tf.reduce_max(input_tensor=variable)
|
||||
)
|
||||
tfv1.summary.scalar(
|
||||
name="%s/min" % name, tensor=tf.reduce_min(input_tensor=variable)
|
||||
)
|
||||
tfv1.summary.histogram(name=name, values=variable)
|
||||
if gradient is not None:
|
||||
if isinstance(gradient, tf.IndexedSlices):
|
||||
grad_values = gradient.values
|
||||
else:
|
||||
grad_values = gradient
|
||||
if grad_values is not None:
|
||||
tfv1.summary.histogram(name="%s/gradients" % name, values=grad_values)
|
||||
|
||||
|
||||
def log_grads_and_vars(grads_and_vars):
|
||||
r"""
|
||||
Let's also introduce a helper function for logging collections of gradient/variable tuples.
|
||||
"""
|
||||
for gradient, variable in grads_and_vars:
|
||||
log_variable(variable, gradient=gradient)
|
||||
|
||||
|
||||
def train():
|
||||
exception_box = ExceptionBox()
|
||||
|
||||
# Create training and validation datasets
|
||||
train_set = create_dataset(
|
||||
Config.train_files,
|
||||
@ -532,13 +284,55 @@ def train():
|
||||
augmentations=Config.augmentations,
|
||||
cache_path=Config.feature_cache,
|
||||
train_phase=True,
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
|
||||
reverse=Config.reverse_train,
|
||||
limit=Config.limit_train,
|
||||
buffering=Config.read_buffer,
|
||||
)
|
||||
|
||||
dev_sets = []
|
||||
if Config.dev_files:
|
||||
dev_sets = [
|
||||
create_dataset(
|
||||
[source],
|
||||
batch_size=Config.dev_batch_size,
|
||||
train_phase=False,
|
||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||
reverse=Config.reverse_dev,
|
||||
limit=Config.limit_dev,
|
||||
buffering=Config.read_buffer,
|
||||
)
|
||||
for source in Config.dev_files
|
||||
]
|
||||
|
||||
metrics_sets = []
|
||||
if Config.metrics_files:
|
||||
metrics_sets = [
|
||||
create_dataset(
|
||||
[source],
|
||||
batch_size=Config.dev_batch_size,
|
||||
train_phase=False,
|
||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||
reverse=Config.reverse_dev,
|
||||
limit=Config.limit_dev,
|
||||
buffering=Config.read_buffer,
|
||||
)
|
||||
for source in Config.metrics_files
|
||||
]
|
||||
|
||||
return train_set, dev_sets, metrics_sets
|
||||
|
||||
|
||||
def train():
|
||||
early_training_checks()
|
||||
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(Config.random_seed)
|
||||
|
||||
train_set, dev_sets, metrics_sets = create_training_datasets()
|
||||
|
||||
iterator = tfv1.data.Iterator.from_structure(
|
||||
tfv1.data.get_output_types(train_set),
|
||||
tfv1.data.get_output_shapes(train_set),
|
||||
@ -547,44 +341,10 @@ def train():
|
||||
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if Config.dev_files:
|
||||
dev_sources = Config.dev_files
|
||||
dev_sets = [
|
||||
create_dataset(
|
||||
[source],
|
||||
batch_size=Config.dev_batch_size,
|
||||
train_phase=False,
|
||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||
reverse=Config.reverse_dev,
|
||||
limit=Config.limit_dev,
|
||||
buffering=Config.read_buffer,
|
||||
)
|
||||
for source in dev_sources
|
||||
]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
|
||||
if Config.metrics_files:
|
||||
metrics_sources = Config.metrics_files
|
||||
metrics_sets = [
|
||||
create_dataset(
|
||||
[source],
|
||||
batch_size=Config.dev_batch_size,
|
||||
train_phase=False,
|
||||
augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
|
||||
exception_box=exception_box,
|
||||
process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
|
||||
reverse=Config.reverse_dev,
|
||||
limit=Config.limit_dev,
|
||||
buffering=Config.read_buffer,
|
||||
)
|
||||
for source in metrics_sources
|
||||
]
|
||||
metrics_init_ops = [
|
||||
iterator.make_initializer(metrics_set) for metrics_set in metrics_sets
|
||||
]
|
||||
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
|
||||
metrics_init_ops = [
|
||||
iterator.make_initializer(metrics_set) for metrics_set in metrics_sets
|
||||
]
|
||||
|
||||
# Dropout
|
||||
dropout_rates = [
|
||||
@ -622,7 +382,6 @@ def train():
|
||||
|
||||
# Average tower gradients across GPUs
|
||||
avg_tower_gradients = average_gradients(gradients)
|
||||
log_grads_and_vars(avg_tower_gradients)
|
||||
|
||||
# global_step is automagically incremented by the optimizer
|
||||
global_step = tfv1.train.get_or_create_global_step()
|
||||
@ -664,6 +423,11 @@ def train():
|
||||
with open_remote(flags_file, "w") as fout:
|
||||
json.dump(Config.serialize(), fout, indent=2)
|
||||
|
||||
# Serialize alphabet alongside checkpoint
|
||||
preserved_alphabet_file = os.path.join(Config.save_checkpoint_dir, "alphabet.txt")
|
||||
with open_remote(preserved_alphabet_file, "wb") as fout:
|
||||
fout.write(Config.alphabet.SerializeText())
|
||||
|
||||
with tfv1.Session(config=Config.session_config) as session:
|
||||
log_debug("Session opened.")
|
||||
|
||||
@ -745,9 +509,7 @@ def train():
|
||||
],
|
||||
feed_dict=feed_dict,
|
||||
)
|
||||
exception_box.raise_if_set()
|
||||
except tf.errors.OutOfRangeError:
|
||||
exception_box.raise_if_set()
|
||||
break
|
||||
|
||||
if problem_files.size > 0:
|
||||
@ -797,7 +559,7 @@ def train():
|
||||
# Validation
|
||||
dev_loss = 0.0
|
||||
total_steps = 0
|
||||
for source, init_op in zip(dev_sources, dev_init_ops):
|
||||
for source, init_op in zip(Config.dev_files, dev_init_ops):
|
||||
log_progress("Validating epoch %d on %s..." % (epoch, source))
|
||||
set_loss, steps = run_set("dev", epoch, init_op, dataset=source)
|
||||
dev_loss += set_loss * steps
|
||||
@ -877,7 +639,7 @@ def train():
|
||||
|
||||
if Config.metrics_files:
|
||||
# Read only metrics, not affecting best validation loss tracking
|
||||
for source, init_op in zip(metrics_sources, metrics_init_ops):
|
||||
for source, init_op in zip(Config.metrics_files, metrics_init_ops):
|
||||
log_progress("Metrics for epoch %d on %s..." % (epoch, source))
|
||||
set_loss, _ = run_set("metrics", epoch, init_op, dataset=source)
|
||||
log_progress(
|
||||
@ -895,392 +657,44 @@ def train():
|
||||
log_debug("Session closed.")
|
||||
|
||||
|
||||
def test():
|
||||
samples = evaluate(Config.test_files, create_model)
|
||||
if Config.test_output_file:
|
||||
save_samples_json(samples, Config.test_output_file)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
input_samples = tfv1.placeholder(
|
||||
tf.float32, [Config.audio_window_samples], "input_samples"
|
||||
)
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
|
||||
mfccs = tf.identity(mfccs, name="mfccs")
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in STT_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
input_tensor = tfv1.placeholder(
|
||||
tf.float32,
|
||||
[
|
||||
batch_size,
|
||||
n_steps if n_steps > 0 else None,
|
||||
2 * Config.n_context + 1,
|
||||
Config.n_input,
|
||||
],
|
||||
name="input_node",
|
||||
)
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
previous_state_c = tfv1.placeholder(
|
||||
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
|
||||
)
|
||||
previous_state_h = tfv1.placeholder(
|
||||
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
|
||||
)
|
||||
|
||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(
|
||||
previous_state_c, previous_state_h
|
||||
)
|
||||
|
||||
# One rate per layer
|
||||
no_dropout = [None] * 6
|
||||
|
||||
if tflite:
|
||||
rnn_impl = rnn_impl_static_rnn
|
||||
else:
|
||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||
|
||||
logits, layers = create_model(
|
||||
batch_x=input_tensor,
|
||||
batch_size=batch_size,
|
||||
seq_length=seq_length if not Config.export_tflite else None,
|
||||
dropout=no_dropout,
|
||||
previous_state=previous_state,
|
||||
overlap=False,
|
||||
rnn_impl=rnn_impl,
|
||||
)
|
||||
|
||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
probs = tf.nn.softmax(logits, name="logits")
|
||||
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
raise NotImplementedError(
|
||||
"dynamic batch_size does not support tflite nor streaming"
|
||||
)
|
||||
if n_steps > 0:
|
||||
raise NotImplementedError(
|
||||
"dynamic batch_size expect n_steps to be dynamic too"
|
||||
)
|
||||
return (
|
||||
{
|
||||
"input": input_tensor,
|
||||
"input_lengths": seq_length,
|
||||
},
|
||||
{
|
||||
"outputs": probs,
|
||||
},
|
||||
layers,
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers["rnn_output_state"]
|
||||
new_state_c = tf.identity(new_state_c, name="new_state_c")
|
||||
new_state_h = tf.identity(new_state_h, name="new_state_h")
|
||||
|
||||
inputs = {
|
||||
"input": input_tensor,
|
||||
"previous_state_c": previous_state_c,
|
||||
"previous_state_h": previous_state_h,
|
||||
"input_samples": input_samples,
|
||||
}
|
||||
|
||||
if not Config.export_tflite:
|
||||
inputs["input_lengths"] = seq_length
|
||||
|
||||
outputs = {
|
||||
"outputs": probs,
|
||||
"new_state_c": new_state_c,
|
||||
"new_state_h": new_state_h,
|
||||
"mfccs": mfccs,
|
||||
# Expose internal layers for downstream applications
|
||||
"layer_3": layers["layer_3"],
|
||||
"layer_5": layers["layer_5"],
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()

def export():
    r"""
    Restores the trained variables into a simpler graph that will be exported for serving.
    """
    log_info("Exporting the model...")

    inputs, outputs, _ = create_inference_graph(
        batch_size=Config.export_batch_size,
        n_steps=Config.n_steps,
        tflite=Config.export_tflite,
    )

    graph_version = int(file_relative_read("GRAPH_VERSION").strip())
    assert graph_version > 0

    outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
    outputs["metadata_sample_rate"] = tf.constant(
        [Config.audio_sample_rate], name="metadata_sample_rate"
    )
    outputs["metadata_feature_win_len"] = tf.constant(
        [Config.feature_win_len], name="metadata_feature_win_len"
    )
    outputs["metadata_feature_win_step"] = tf.constant(
        [Config.feature_win_step], name="metadata_feature_win_step"
    )
    outputs["metadata_beam_width"] = tf.constant(
        [Config.export_beam_width], name="metadata_beam_width"
    )
    outputs["metadata_alphabet"] = tf.constant(
        [Config.alphabet.Serialize()], name="metadata_alphabet"
    )

    if Config.export_language:
        outputs["metadata_language"] = tf.constant(
            [Config.export_language.encode("utf-8")], name="metadata_language"
        )

    # Prevent further graph changes
    tfv1.get_default_graph().finalize()

    output_names_tensors = [
        tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)
    ]
    output_names_ops = [
        op.name for op in outputs.values() if isinstance(op, tf.Operation)
    ]
    output_names = output_names_tensors + output_names_ops

    with tf.Session() as session:
        # Restore variables from checkpoint
        load_graph_for_evaluation(session)

        output_filename = Config.export_file_name + ".pb"
        if Config.remove_export:
            if isdir_remote(Config.export_dir):
                log_info("Removing old export")
                remove_remote(Config.export_dir)

        output_graph_path = os.path.join(Config.export_dir, output_filename)

        if not is_remote_path(Config.export_dir) and not os.path.isdir(
            Config.export_dir
        ):
            os.makedirs(Config.export_dir)

        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=output_names,
        )

        frozen_graph = tfv1.graph_util.extract_sub_graph(
            graph_def=frozen_graph, dest_nodes=output_names
        )

        if not Config.export_tflite:
            with open_remote(output_graph_path, "wb") as fout:
                fout.write(frozen_graph.SerializeToString())
        else:
            output_tflite_path = os.path.join(
                Config.export_dir, output_filename.replace(".pb", ".tflite")
            )

            converter = tf.lite.TFLiteConverter(
                frozen_graph,
                input_tensors=inputs.values(),
                output_tensors=outputs.values(),
            )
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
            converter.allow_custom_ops = True
            tflite_model = converter.convert()

            with open_remote(output_tflite_path, "wb") as fout:
                fout.write(tflite_model)

    log_info("Models exported at %s" % (Config.export_dir))

    metadata_fname = os.path.join(
        Config.export_dir,
        "{}_{}_{}.md".format(
            Config.export_author_id,
            Config.export_model_name,
            Config.export_model_version,
        ),
    )

    model_runtime = "tflite" if Config.export_tflite else "tensorflow"
    with open_remote(metadata_fname, "w") as f:
        f.write("---\n")
        f.write("author: {}\n".format(Config.export_author_id))
        f.write("model_name: {}\n".format(Config.export_model_name))
        f.write("model_version: {}\n".format(Config.export_model_version))
        f.write("contact_info: {}\n".format(Config.export_contact_info))
        f.write("license: {}\n".format(Config.export_license))
        f.write("language: {}\n".format(Config.export_language))
        f.write("runtime: {}\n".format(model_runtime))
        f.write("min_stt_version: {}\n".format(Config.export_min_stt_version))
        f.write("max_stt_version: {}\n".format(Config.export_max_stt_version))
        f.write(
            "acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n"
        )
        f.write(
            "scorer_url: <replace this with a publicly available URL of the scorer, if present>\n"
        )
        f.write("---\n")
        f.write("{}\n".format(Config.export_description))

    log_info(
        "Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.".format(
            metadata_fname
        )
    )
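For context, a minimal sketch of driving this export path programmatically rather than via the CLI. It assumes the coqui_stt_training.export module referenced in main() below and the initialize_globals_from_args helper added to util/config.py later in this commit; the paths are hypothetical.

from coqui_stt_training.util.config import initialize_globals_from_args
from coqui_stt_training import export as export_module

# Hypothetical paths; export_tflite=False writes a frozen .pb instead of .tflite
initialize_globals_from_args(
    checkpoint_dir="ckpt/",
    export_dir="export/",
    export_tflite=False,
)
export_module.export()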

def package_zip():
    # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
    export_dir = os.path.join(
        os.path.abspath(Config.export_dir), ""
    )  # Force ending '/'
    if is_remote_path(export_dir):
        log_error(
            "Cannot package remote path zip %s. Please do this manually." % export_dir
        )
        return

    zip_filename = os.path.dirname(export_dir)

    shutil.copy(Config.scorer_path, export_dir)

    archive = shutil.make_archive(zip_filename, "zip", export_dir)
    log_info("Exported packaged model {}".format(archive))

def do_single_file_inference(input_file_path):
    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        load_graph_for_evaluation(session)

        features, features_len = audiofile_to_features(input_file_path)
        previous_state_c = np.zeros([1, Config.n_cell_dim])
        previous_state_h = np.zeros([1, Config.n_cell_dim])

        # Add batch dimension
        features = tf.expand_dims(features, 0)
        features_len = tf.expand_dims(features_len, 0)

        # Evaluate
        features = create_overlapping_windows(features).eval(session=session)
        features_len = features_len.eval(session=session)

        probs = outputs["outputs"].eval(
            feed_dict={
                inputs["input"]: features,
                inputs["input_lengths"]: features_len,
                inputs["previous_state_c"]: previous_state_c,
                inputs["previous_state_h"]: previous_state_h,
            },
            session=session,
        )

        probs = np.squeeze(probs)

        if Config.scorer_path:
            scorer = Scorer(
                Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
            )
        else:
            scorer = None
        decoded = ctc_beam_search_decoder(
            probs,
            Config.alphabet,
            Config.beam_width,
            scorer=scorer,
            cutoff_prob=Config.cutoff_prob,
            cutoff_top_n=Config.cutoff_top_n,
        )
        # Print highest probability result
        print(decoded[0][1])

def early_training_checks():
    # Check for proper scorer early
    if Config.scorer_path:
        scorer = Scorer(
            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
        )
        del scorer

    if (
        Config.train_files
        and Config.test_files
        and Config.load_checkpoint_dir != Config.save_checkpoint_dir
    ):
        log_warn(
            "WARNING: You specified different values for --load_checkpoint_dir "
            "and --save_checkpoint_dir, but you are running training and testing "
            "in a single invocation. The testing step will respect --load_checkpoint_dir, "
            "and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. "
            "Train and test in two separate invocations, specifying the correct "
            "--load_checkpoint_dir in both cases, or use the same location "
            "for loading and saving."
        )

def main():
-    initialize_globals()
    initialize_globals_from_cli()
    early_training_checks()

    def deprecated_msg(prefix):
        return (
            f"{prefix} Using the training script as a generic driver for all training "
            "related functionality is deprecated and will be removed soon. Use "
            "the specific scripts: train.py/evaluate.py/export.py/training_graph_inference.py."
        )

    if Config.train_files:
        tfv1.reset_default_graph()
        tfv1.set_random_seed(Config.random_seed)
        train()
    else:
        log_warn(deprecated_msg("Calling training script without --train_files."))

    if Config.test_files:
-        tfv1.reset_default_graph()
-        test()
-
-    if Config.export_dir and not Config.export_zip:
-        tfv1.reset_default_graph()
-        export()
-
-    if Config.export_zip:
-        tfv1.reset_default_graph()
-        Config.export_tflite = True
-
-        if listdir_remote(Config.export_dir):
-            log_error(
-                "Directory {} is not empty, please fix this.".format(Config.export_dir)
-            )
-            sys.exit(1)
-
-        export()
-        package_zip()
        log_warn(
            deprecated_msg(
                "Specifying --test_files when calling train.py script. Use evaluate.py."
            )
        )
        evaluate.test()

    if Config.export_dir:
        log_warn(
            deprecated_msg(
                "Specifying --export_dir when calling train.py script. Use export.py."
            )
        )
        export.export()

    if Config.one_shot_infer:
-        tfv1.reset_default_graph()
-        do_single_file_inference(Config.one_shot_infer)
        log_warn(
            deprecated_msg(
                "Specifying --one_shot_infer when calling train.py script. Use training_graph_inference.py."
            )
        )
        training_graph_inference.do_single_file_inference(Config.one_shot_infer)


if __name__ == "__main__":
    main()

87 training/coqui_stt_training/training_graph_inference.py Normal file
@ -0,0 +1,87 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys

LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0
DESIRED_LOG_LEVEL = (
    sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3"
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tfv1

from coqui_stt_ctcdecoder import ctc_beam_search_decoder, Scorer
from .deepspeech_model import create_inference_graph, create_overlapping_windows
from .util.checkpoints import load_graph_for_evaluation
from .util.config import Config, initialize_globals_from_cli, log_error
from .util.feeding import audiofile_to_features


def do_single_file_inference(input_file_path):
    tfv1.reset_default_graph()

    with tfv1.Session(config=Config.session_config) as session:
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        load_graph_for_evaluation(session)

        features, features_len = audiofile_to_features(input_file_path)
        previous_state_c = np.zeros([1, Config.n_cell_dim])
        previous_state_h = np.zeros([1, Config.n_cell_dim])

        # Add batch dimension
        features = tf.expand_dims(features, 0)
        features_len = tf.expand_dims(features_len, 0)

        # Evaluate
        features = create_overlapping_windows(features).eval(session=session)
        features_len = features_len.eval(session=session)

        probs = outputs["outputs"].eval(
            feed_dict={
                inputs["input"]: features,
                inputs["input_lengths"]: features_len,
                inputs["previous_state_c"]: previous_state_c,
                inputs["previous_state_h"]: previous_state_h,
            },
            session=session,
        )

        probs = np.squeeze(probs)

        if Config.scorer_path:
            scorer = Scorer(
                Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
            )
        else:
            scorer = None
        decoded = ctc_beam_search_decoder(
            probs,
            Config.alphabet,
            Config.beam_width,
            scorer=scorer,
            cutoff_prob=Config.cutoff_prob,
            cutoff_top_n=Config.cutoff_top_n,
        )
        # Print highest probability result
        print(decoded[0][1])


def main():
    initialize_globals_from_cli()

    if Config.one_shot_infer:
        tfv1.reset_default_graph()
        do_single_file_inference(Config.one_shot_infer)
    else:
        raise RuntimeError(
            "Calling training_graph_inference script directly but no --one_shot_infer input audio file specified"
        )


if __name__ == "__main__":
    main()
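A sketch of exercising the new module directly, assuming the same package layout; the checkpoint and audio paths are hypothetical, and an empty scorer_path simply decodes without the external language model.

from coqui_stt_training.util.config import initialize_globals_from_args
from coqui_stt_training.training_graph_inference import do_single_file_inference

initialize_globals_from_args(
    checkpoint_dir="ckpt/",  # hypothetical checkpoint location
    scorer_path="",          # falsy -> beam search without an external scorer
)
do_single_file_inference("audio/sample.wav")  # hypothetical input WAV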

194 training/coqui_stt_training/util/auto_input.py Normal file
@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path
from typing import Optional

import pandas
from tqdm import tqdm

from .io import open_remote
from .sample_collections import samples_from_sources
from coqui_stt_ctcdecoder import Alphabet


def create_alphabet_from_sources(sources: [str]) -> ([str], Alphabet):
    """Generate an Alphabet from characters in given sources.

    sources: List of paths to input sources (CSV, SDB).

    Returns a 2-tuple with list of characters and Alphabet instance.
    """
    characters = set()
    for sample in tqdm(samples_from_sources(sources)):
        characters |= set(sample.transcript)
    characters = list(sorted(characters))
    alphabet = Alphabet()
    alphabet.InitFromLabels(characters)
    return characters, alphabet


def _get_sample_size(population_size):
    """calculates the sample size for a 99% confidence and 1% margin of error"""
    margin_of_error = 0.01
    fraction_picking = 0.50
    z_score = 2.58  # Corresponds to confidence level 99%
    numerator = (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
        margin_of_error ** 2
    )
    sample_size = 0
    for train_size in range(population_size, 0, -1):
        denominator = 1 + (z_score ** 2 * fraction_picking * (1 - fraction_picking)) / (
            margin_of_error ** 2 * train_size
        )
        sample_size = int(numerator / denominator)
        if 2 * sample_size + train_size <= population_size:
            break
    return sample_size
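For a candidate train size t, the loop above computes n(t) = z^2 p(1 - p) t / (e^2 t + z^2 p(1 - p)), the finite-population-corrected sample size, and walks t down from the full corpus until two such sets plus the train remainder fit. An illustrative probe (run it to see the concrete values):

for population in (1_000, 10_000, 100_000):
    n = _get_sample_size(population)
    print(f"{population} total -> dev/test size {n}, train size {population - 2 * n}")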

def _split_sets(samples: pandas.DataFrame, sample_size):
    """
    Randomly split the dataset into train, validation, and test sets, where the
    size of the validation and test sets is determined by `_get_sample_size`.
    """
    samples = samples.sample(frac=1).reset_index(drop=True)

    train_beg = 0
    train_end = len(samples) - 2 * sample_size

    dev_beg = train_end
    dev_end = train_end + sample_size

    test_beg = dev_end
    test_end = len(samples)

    return (
        samples[train_beg:train_end],
        samples[dev_beg:dev_end],
        samples[test_beg:test_end],
    )

def create_datasets_from_auto_input(
    auto_input_dataset: Path, alphabet_config_path: Optional[Path]
) -> (Path, Path, Path, Path):
    """Creates training datasets from --auto_input_dataset flag.

    auto_input_dataset: Path to input CSV or folder containing CSV.

    Returns paths to generated train set, dev set and test set, and the path
    to the alphabet file, either generated from the data, existing alongside
    data, or specified manually by the user.
    """
    if auto_input_dataset.is_dir():
        auto_input_dir = auto_input_dataset
        all_csvs = list(auto_input_dataset.glob("*.csv"))
        if not all_csvs:
            raise RuntimeError(
                "--auto_input_dataset is a directory but no CSV file was found "
                "inside of it. Either make sure a CSV file is in the directory "
                "or specify the file directly."
            )

        non_subsets = [f for f in all_csvs if f.stem not in ("train", "dev", "test")]
        if len(non_subsets) == 1:
            auto_input_csv = non_subsets[0]
        elif len(non_subsets) > 1:
            non_subsets_fmt = ", ".join(str(s) for s in non_subsets)
            raise RuntimeError(
                "--auto_input_dataset is a directory but there are multiple CSV "
                f"files not matching a subset name (train/dev/test): {non_subsets_fmt}. "
                "Either remove extraneous CSV files or specify the correct file "
                "to use for dataset formatting directly instead of the directory."
            )
        # else (empty) -> fall through, sets already present and get picked up below
    else:
        auto_input_dir = auto_input_dataset.parent
        auto_input_csv = auto_input_dataset

    train_set_path = auto_input_dir / "train.csv"
    dev_set_path = auto_input_dir / "dev.csv"
    test_set_path = auto_input_dir / "test.csv"

    if train_set_path.exists() != dev_set_path.exists() != test_set_path.exists():
        raise RuntimeError(
            "Specifying --auto_input_dataset with some generated files present "
            "and some missing. Either all three sets (train.csv, dev.csv, test.csv) "
            f"should exist alongside {auto_input_csv} (in which case they will be used), "
            "or none of those files should exist (in which case they will be generated)."
        )

    print(f"I Processing --auto_input_dataset input: {auto_input_csv}...")
    df = pandas.read_csv(auto_input_csv)

    if set(df.columns) < set(("wav_filename", "wav_filesize", "transcript")):
        raise RuntimeError(
            "Missing columns in --auto_input_dataset CSV. STT training inputs "
            "require wav_filename, wav_filesize, and transcript columns."
        )

    dev_test_size = _get_sample_size(len(df))
    if dev_test_size == 0:
        if len(df) >= 2:
            dev_test_size = 1
        else:
            raise RuntimeError(
                "--auto_input_dataset dataset is too small for automatic splitting "
                "into sets. Specify a larger input dataset or split it manually."
            )

    data_characters = sorted(list(set("".join(df["transcript"].values))))
    alphabet_alongside_data_path = auto_input_dir / "alphabet.txt"
    if alphabet_config_path:
        alphabet = Alphabet(str(alphabet_config_path))
        if not alphabet.CanEncode("".join(data_characters)):
            raise RuntimeError(
                "--alphabet_config_path was specified alongside --auto_input_dataset, "
                "but alphabet contents don't match dataset transcripts. Make sure the "
                "alphabet covers all transcripts or leave --alphabet_config_path "
                "unspecified so that one will be generated automatically."
            )
        print(f"I Using specified --alphabet_config_path: {alphabet_config_path}")
        generated_alphabet_path = alphabet_config_path
    elif alphabet_alongside_data_path.exists():
        alphabet = Alphabet(str(alphabet_alongside_data_path))
        if not alphabet.CanEncode("".join(data_characters)):
            raise RuntimeError(
                "alphabet.txt exists alongside --auto_input_dataset file, but "
                "alphabet contents don't match dataset transcripts. Make sure the "
                "alphabet covers all transcripts or remove the alphabet.txt file "
                "from the data folder so that one will be generated automatically."
            )
        generated_alphabet_path = alphabet_alongside_data_path
        print(f"I Using existing alphabet file: {alphabet_alongside_data_path}")
    else:
        alphabet = Alphabet()
        alphabet.InitFromLabels(data_characters)
        generated_alphabet_path = auto_input_dir / "alphabet.txt"
        print(
            f"I Saved generated alphabet with characters ({data_characters}) into {generated_alphabet_path}"
        )
        with open_remote(str(generated_alphabet_path), "wb") as fout:
            fout.write(alphabet.SerializeText())

    # If splits don't already exist, generate and save them.
    # We check above that all three splits either exist or don't exist together,
    # so we can check a single one for existence here.
    if not train_set_path.exists():
        train_set, dev_set, test_set = _split_sets(df, dev_test_size)
        print(f"I Generated train set size: {len(train_set)} samples.")
        print(f"I Generated validation set size: {len(dev_set)} samples.")
        print(f"I Generated test set size: {len(test_set)} samples.")

        print(f"I Writing train set to {train_set_path}")
        train_set.to_csv(train_set_path, index=False)

        print(f"I Writing dev set to {dev_set_path}")
        dev_set.to_csv(dev_set_path, index=False)

        print(f"I Writing test set to {test_set_path}")
        test_set.to_csv(test_set_path, index=False)
    else:
        print("I Generated splits found alongside --auto_input_dataset, using them.")

    return train_set_path, dev_set_path, test_set_path, generated_alphabet_path
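A sketch of calling the splitter directly, mirroring what the --auto_input_dataset flag triggers from config initialization below; the CSV path is hypothetical.

from pathlib import Path
from coqui_stt_training.util.auto_input import create_datasets_from_auto_input

train_csv, dev_csv, test_csv, alphabet = create_datasets_from_auto_input(
    Path("data/my_corpus.csv"),  # single CSV with wav_filename/wav_filesize/transcript
    None,  # no --alphabet_config_path given -> alphabet.txt is generated
)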

437 training/coqui_stt_training/util/config.py Executable file → Normal file
@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function
import os
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import List

import progressbar

@ -13,6 +14,7 @@ from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
from xdg import BaseDirectory as xdg

from .augmentations import NormalizeSampleRate, parse_augmentations
from .auto_input import create_alphabet_from_sources, create_datasets_from_auto_input
from .gpu import get_available_gpus
from .helpers import parse_file_size
from .io import path_exists_remote
@ -22,7 +24,7 @@ class _ConfigSingleton:
    _config = None

    def __getattr__(self, name):
-        if not _ConfigSingleton._config:
        if _ConfigSingleton._config is None:
            raise RuntimeError("Global configuration not yet initialized.")
        if not hasattr(_ConfigSingleton._config, name):
            raise RuntimeError(
@ -36,10 +38,263 @@ Config = _ConfigSingleton()  # pylint: disable=invalid-name

@dataclass
class _SttConfig(Coqpit):
    def __post_init__(self):
        # Augmentations
        self.augmentations = parse_augmentations(self.augment)
        if self.augmentations:
            print(f"Parsed augmentations: {self.augmentations}")
        if self.augmentations and self.feature_cache and self.cache_for_epochs == 0:
            print(
                "Due to your feature-cache settings, augmentations of "
                "the first epoch will be repeated on all following epochs. "
                "This may lead to unintended over-fitting. "
                "You can use --cache_for_epochs <n_epochs> to invalidate "
                "the cache after a given number of epochs."
            )

        if self.normalize_sample_rate:
            self.augmentations = [NormalizeSampleRate(self.audio_sample_rate)] + self[
                "augmentations"
            ]

        # Caching
        if self.cache_for_epochs == 1:
            print(
                "--cache_for_epochs == 1 is (re-)creating the feature cache "
                "on every epoch but will never use it. You can either set "
                "--cache_for_epochs > 1, or not use feature caching at all."
            )

        # Read-buffer
        self.read_buffer = parse_file_size(self.read_buffer)

        # Set default dropout rates
        if self.dropout_rate2 < 0:
            self.dropout_rate2 = self.dropout_rate
        if self.dropout_rate3 < 0:
            self.dropout_rate3 = self.dropout_rate
        if self.dropout_rate6 < 0:
            self.dropout_rate6 = self.dropout_rate

        # Checkpoint dir logic #
        if self.checkpoint_dir:
            # checkpoint_dir always overrides {save,load}_checkpoint_dir
            self.save_checkpoint_dir = self.checkpoint_dir
            self.load_checkpoint_dir = self.checkpoint_dir
        else:
            if not self.save_checkpoint_dir:
                self.save_checkpoint_dir = xdg.save_data_path(
                    os.path.join("stt", "checkpoints")
                )
            if not self.load_checkpoint_dir:
                self.load_checkpoint_dir = xdg.save_data_path(
                    os.path.join("stt", "checkpoints")
                )

        if self.load_train not in ["last", "best", "init", "auto"]:
            self.load_train = "auto"

        if self.load_evaluate not in ["last", "best", "auto"]:
            self.load_evaluate = "auto"

        # Set default summary dir
        if not self.summary_dir:
            self.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))

        # Standard session configuration that'll be used for all new sessions.
        self.session_config = tfv1.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=self.log_placement,
            inter_op_parallelism_threads=self.inter_op_parallelism_threads,
            intra_op_parallelism_threads=self.intra_op_parallelism_threads,
            gpu_options=tfv1.GPUOptions(allow_growth=self.use_allow_growth),
        )

        # CPU device
        self.cpu_device = "/cpu:0"

        # Available GPU devices
        self.available_devices = get_available_gpus(self.session_config)

        # If there is no GPU available, we fall back to CPU based operation
        if not self.available_devices:
            self.available_devices = [self.cpu_device]

        # If neither `--alphabet_config_path` nor `--bytes_output_mode` were specified,
        # look for an alphabet file alongside the loaded checkpoint.
        loaded_checkpoint_alphabet_file = os.path.join(
            self.load_checkpoint_dir, "alphabet.txt"
        )
        saved_checkpoint_alphabet_file = os.path.join(
            self.save_checkpoint_dir, "alphabet.txt"
        )

        if not (
            bool(self.auto_input_dataset)
            != (self.train_files or self.dev_files or self.test_files)
        ):
            raise RuntimeError(
                "When using --auto_input_dataset, do not specify --train_files, "
                "--dev_files, or --test_files."
            )

        if self.auto_input_dataset:
            (
                gen_train,
                gen_dev,
                gen_test,
                gen_alphabet,
            ) = create_datasets_from_auto_input(
                Path(self.auto_input_dataset),
                Path(self.alphabet_config_path) if self.alphabet_config_path else None,
            )
            self.train_files = [str(gen_train)]
            self.dev_files = [str(gen_dev)]
            self.test_files = [str(gen_test)]
            self.alphabet_config_path = str(gen_alphabet)

        if self.bytes_output_mode and self.alphabet_config_path:
            raise RuntimeError(
                "You cannot set --alphabet_config_path *and* --bytes_output_mode"
            )
        elif self.bytes_output_mode:
            self.alphabet = UTF8Alphabet()
        elif self.alphabet_config_path:
            self.alphabet = Alphabet(self.alphabet_config_path)
        elif os.path.exists(loaded_checkpoint_alphabet_file):
            print(
                "I --alphabet_config_path not specified, but found an alphabet file "
                f"alongside specified checkpoint ({loaded_checkpoint_alphabet_file}). "
                "Will use this alphabet file for this run."
            )
            self.alphabet = Alphabet(loaded_checkpoint_alphabet_file)
        elif self.train_files and self.dev_files and self.test_files:
            # If all subsets are in the same folder and there's an alphabet file
            # alongside them, use it.
            self.alphabet = None
            sources = self.train_files + self.dev_files + self.test_files
            parents = set(Path(p).parent for p in sources)
            if len(parents) == 1:
                possible_alphabet = list(parents)[0] / "alphabet.txt"
                if possible_alphabet.exists():
                    print(
                        "I --alphabet_config_path not specified, but all input "
                        "datasets are present and in the same folder (--train_files, "
                        "--dev_files and --test_files), and an alphabet.txt file "
                        f"was found alongside the sets ({possible_alphabet}). "
                        "Will use this alphabet file for this run."
                    )
                    self.alphabet = Alphabet(str(possible_alphabet))

            if not self.alphabet:
                # Generate alphabet automatically from input dataset, but only if
                # fully specified, to avoid confusion in case a missing set has extra
                # characters.
                print(
                    "I --alphabet_config_path not specified, but all input datasets are "
                    "present (--train_files, --dev_files, --test_files). An alphabet "
                    "will be generated automatically from the data and placed alongside "
                    f"the checkpoint ({saved_checkpoint_alphabet_file})."
                )
                characters, alphabet = create_alphabet_from_sources(sources)
                print(f"I Generated alphabet characters: {characters}.")
                self.alphabet = alphabet
        else:
            raise RuntimeError(
                "Missing --alphabet_config_path flag. Couldn't find an alphabet file\n"
                "alongside checkpoint, and input datasets are not fully specified\n"
                "(--train_files, --dev_files, --test_files), so can't generate an alphabet.\n"
                "Either specify an alphabet file or fully specify the dataset, so one will\n"
                "be generated automatically."
            )

        # Geometric Constants
        # ===================

        # For an explanation of the meaning of the geometric constants
        # please refer to doc/Geometry.md

        # Number of MFCC features
        self.n_input = 26  # TODO: Determine this programmatically from the sample rate

        # The number of frames in the context
        self.n_context = (
            9  # TODO: Determine the optimal value using a validation data set
        )

        # Number of units in hidden layers
        self.n_hidden = self.n_hidden

        self.n_hidden_1 = self.n_hidden

        self.n_hidden_2 = self.n_hidden

        self.n_hidden_5 = self.n_hidden

        # LSTM cell state dimension
        self.n_cell_dim = self.n_hidden

        # The number of units in the third layer, which feeds in to the LSTM
        self.n_hidden_3 = self.n_cell_dim

        # Dims in last layer = number of characters in alphabet plus one
        # +1 for CTC blank label
        self.n_hidden_6 = self.alphabet.GetSize() + 1

        # Size of audio window in samples
        if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
            raise RuntimeError(
                "--feature_win_len value ({}) in milliseconds ({}) multiplied "
                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
                "your --feature_win_len value or resample your audio accordingly."
                "".format(
                    self.feature_win_len,
                    self.feature_win_len / 1000,
                    self.audio_sample_rate,
                )
            )

        self.audio_window_samples = self.audio_sample_rate * (
            self.feature_win_len / 1000
        )

        # Stride for feature computations in samples
        if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
            raise RuntimeError(
                "--feature_win_step value ({}) in milliseconds ({}) multiplied "
                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
                "your --feature_win_step value or resample your audio accordingly."
                "".format(
                    self.feature_win_step,
                    self.feature_win_step / 1000,
                    self.audio_sample_rate,
                )
            )

        self.audio_step_samples = self.audio_sample_rate * (
            self.feature_win_step / 1000
        )

        if self.one_shot_infer and not path_exists_remote(self.one_shot_infer):
            raise RuntimeError(
                "Path specified in --one_shot_infer is not a valid file."
            )

        if self.train_cudnn and self.load_cudnn:
            raise RuntimeError(
                "Trying to use --train_cudnn, but --load_cudnn "
                "was also specified. The --load_cudnn flag is only "
                "needed when converting a CuDNN RNN checkpoint to "
                "a CPU-capable graph. If your system is capable of "
                "using CuDNN RNN, you can just specify the CuDNN RNN "
                "checkpoint normally with --save_checkpoint_dir."
            )

    # sphinx-doc: training_ref_flags_start
    train_files: List[str] = field(
        default_factory=list,
        metadata=dict(
-            help="space-separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run."
            help="space-separated list of files specifying the datasets used for training. Multiple files will get merged. If empty, training will not be run."
        ),
    )
    dev_files: List[str] = field(
@ -60,6 +315,12 @@ class _SttConfig(Coqpit):
            help="space-separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed."
        ),
    )
    auto_input_dataset: str = field(
        default="",
        metadata=dict(
            help="path to a single CSV file to use for training. Cannot be specified alongside --train_files, --dev_files, --test_files. Training/validation/testing subsets will be automatically generated from the input, alongside with an alphabet file, if not already present.",
        ),
    )

    read_buffer: str = field(
        default="1MB",
@ -297,7 +558,11 @@ class _SttConfig(Coqpit):
        default=False, metadata=dict(help="whether to remove old exported models")
    )
    export_tflite: bool = field(
-        default=False, metadata=dict(help="export a graph ready for TF Lite engine")
        default=True, metadata=dict(help="export a graph ready for TF Lite engine")
    )
    export_quantize: bool = field(
        default=True,
        metadata=dict(help="export a quantized model (optimized for size)"),
    )
    n_steps: int = field(
        default=16,
@ -472,7 +737,7 @@ class _SttConfig(Coqpit):
        ),
    )
    alphabet_config_path: str = field(
-        default="data/alphabet.txt",
        default="",
        metadata=dict(
            help="path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format."
        ),
@ -539,167 +804,17 @@ class _SttConfig(Coqpit):
            help="the number of trials to run during hyperparameter optimization."
        ),
    )

    def check_values(self):
        c = asdict(self)
        check_argument("alphabet_config_path", c, is_path=True)
        check_argument("one_shot_infer", c, is_path=True)
    # sphinx-doc: training_ref_flags_end


-def initialize_globals():
-    c = _SttConfig()
-    c.parse_args(arg_prefix="")
def initialize_globals_from_cli():
    c = _SttConfig.init_from_argparse(arg_prefix="")
    _ConfigSingleton._config = c  # pylint: disable=protected-access

-    # Augmentations
-    c.augmentations = parse_augmentations(c.augment)
-    print(f"Parsed augmentations from flags: {c.augmentations}")
-    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
-        print(
-            "Due to current feature-cache settings the exact same sample augmentations of the first "
-            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
-            "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
-        )
-
-    if c.normalize_sample_rate:
-        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c[
-            "augmentations"
-        ]
-
-    # Caching
-    if c.cache_for_epochs == 1:
-        print(
-            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
-        )
-
-    # Read-buffer
-    c.read_buffer = parse_file_size(c.read_buffer)
-
-    # Set default dropout rates
-    if c.dropout_rate2 < 0:
-        c.dropout_rate2 = c.dropout_rate
-    if c.dropout_rate3 < 0:
-        c.dropout_rate3 = c.dropout_rate
-    if c.dropout_rate6 < 0:
-        c.dropout_rate6 = c.dropout_rate
-
-    # Set default checkpoint dir
-    if not c.checkpoint_dir:
-        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
-
-    if c.load_train not in ["last", "best", "init", "auto"]:
-        c.load_train = "auto"
-
-    if c.load_evaluate not in ["last", "best", "auto"]:
-        c.load_evaluate = "auto"
-
-    # Set default summary dir
-    if not c.summary_dir:
-        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
-
-    # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tfv1.ConfigProto(
-        allow_soft_placement=True,
-        log_device_placement=c.log_placement,
-        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
-        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
-        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
-    )
-
-    # CPU device
-    c.cpu_device = "/cpu:0"
-
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
-
-    # If there is no GPU available, we fall back to CPU based operation
-    if not c.available_devices:
-        c.available_devices = [c.cpu_device]
-
-    if c.bytes_output_mode:
-        c.alphabet = UTF8Alphabet()
-    else:
-        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))
-
-    # Geometric Constants
-    # ===================
-
-    # For an explanation of the meaning of the geometric constants, please refer to
-    # doc/Geometry.md
-
-    # Number of MFCC features
-    c.n_input = 26  # TODO: Determine this programmatically from the sample rate
-
-    # The number of frames in the context
-    c.n_context = 9  # TODO: Determine the optimal value using a validation data set
-
-    # Number of units in hidden layers
-    c.n_hidden = c.n_hidden
-
-    c.n_hidden_1 = c.n_hidden
-
-    c.n_hidden_2 = c.n_hidden
-
-    c.n_hidden_5 = c.n_hidden
-
-    # LSTM cell state dimension
-    c.n_cell_dim = c.n_hidden
-
-    # The number of units in the third layer, which feeds in to the LSTM
-    c.n_hidden_3 = c.n_cell_dim
-
-    # Units in the sixth layer = number of characters in the target language plus one
-    c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label
-
-    # Size of audio window in samples
-    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_len value or resample your audio accordingly."
-            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
-        )
-        sys.exit(1)
-
-    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)
-
-    # Stride for feature computations in samples
-    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_step value or resample your audio accordingly."
-            "".format(
-                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
-            )
-        )
-        sys.exit(1)
-
-    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)
-
-    if c.one_shot_infer:
-        if not path_exists_remote(c.one_shot_infer):
-            log_error("Path specified in --one_shot_infer is not a valid file.")
-            sys.exit(1)
-
-    if c.train_cudnn and c.load_cudnn:
-        log_error(
-            "Trying to use --train_cudnn, but --load_cudnn "
-            "was also specified. The --load_cudnn flag is only "
-            "needed when converting a CuDNN RNN checkpoint to "
-            "a CPU-capable graph. If your system is capable of "
-            "using CuDNN RNN, you can just specify the CuDNN RNN "
-            "checkpoint normally with --save_checkpoint_dir."
-        )
-        sys.exit(1)
-
-    # If separate save and load flags were not specified, default to load and save
-    # from the same dir.
-    if not c.save_checkpoint_dir:
-        c.save_checkpoint_dir = c.checkpoint_dir
-
-    if not c.load_checkpoint_dir:
-        c.load_checkpoint_dir = c.checkpoint_dir


def initialize_globals_from_args(**override_args):
    # Update Config with new args
    c = _SttConfig(**override_args)
    _ConfigSingleton._config = c  # pylint: disable=protected-access
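A worked example of the feature-window validation in __post_init__ above: sample_rate multiplied by the window length in milliseconds must be divisible by 1000 so that windows and strides land on whole sample counts. At the 16 kHz default any integer millisecond value passes, but not at, say, 44.1 kHz:

# 16 kHz, 32 ms window -> 16000 * 32 / 1000 = 512 samples   (accepted)
# 44.1 kHz, 32 ms window -> 44100 * 32 / 1000 = 1411.2      (rejected)
# 44.1 kHz, 20 ms window -> 44100 * 20 / 1000 = 882 samples (accepted)
assert (16000 * 32) % 1000 == 0
assert (44100 * 32) % 1000 != 0
assert (44100 * 20) % 1000 == 0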

training/coqui_stt_training/util/downloader.py
@ -1,5 +1,6 @@
from os import makedirs, path

from tqdm import tqdm
import progressbar
import requests

@ -26,17 +27,11 @@ def maybe_download(archive_name, target_dir, archive_url):
        print('No archive "%s" - downloading...' % archive_path)
        req = requests.get(archive_url, stream=True)
        total_size = int(req.headers.get("content-length", 0))
-        done = 0
        with open_remote(archive_path, "wb") as f:
-            bar = progressbar.ProgressBar(
-                max_value=total_size if total_size > 0 else progressbar.UnknownLength,
-                widgets=SIMPLE_BAR,
-            )
-
-            for data in req.iter_content(1024 * 1024):
-                done += len(data)
-                f.write(data)
-                bar.update(done)
            with tqdm(total=total_size) as bar:
                for data in req.iter_content(1024 * 1024):
                    f.write(data)
                    bar.update(len(data))
    else:
        print('Found archive "%s" - not downloading.' % archive_path)
    return archive_path
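The replacement download loop in isolation, as a self-contained sketch; the URL and output filename are hypothetical, and the unit/unit_scale options, not used in the patch, only make tqdm's byte counts human-readable.

import requests
from tqdm import tqdm

url = "https://example.com/archive.tar.gz"  # hypothetical
req = requests.get(url, stream=True)
total = int(req.headers.get("content-length", 0))
with open("archive.tar.gz", "wb") as f, tqdm(total=total, unit="B", unit_scale=True) as bar:
    for chunk in req.iter_content(1024 * 1024):  # stream in 1 MiB chunks
        f.write(chunk)
        bar.update(len(chunk))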

training/coqui_stt_training/util/feeding.py
@ -12,7 +12,7 @@ import tensorflow as tf
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
from .config import Config
-from .helpers import MEGABYTE, remember_exception
from .helpers import MEGABYTE
from .sample_collections import samples_from_sources
from .text import text_to_char_array

@ -138,7 +138,6 @@ def create_dataset(
    train_phase=False,
    reverse=False,
    limit=0,
-    exception_box=None,
    process_ahead=None,
    buffering=1 * MEGABYTE,
):
@ -197,7 +196,7 @@ def create_dataset(
    )

    dataset = tf.data.Dataset.from_generator(
-        remember_exception(generate_values, exception_box),
        generate_values,
        output_types=(
            tf.string,
            tf.float32,
@ -223,7 +222,6 @@ def split_audio_file(
    aggressiveness=3,
    outlier_duration_ms=10000,
    outlier_batch_size=1,
-    exception_box=None,
):
    def generate_values():
        frames = read_frames_from_file(audio_path)
@ -240,7 +238,7 @@ def split_audio_file(
    def create_batch_set(bs, criteria):
        return (
            tf.data.Dataset.from_generator(
-                remember_exception(generate_values, exception_box),
                generate_values,
                output_types=(tf.int32, tf.int32, tf.float32),
            )
            .map(to_mfccs, num_parallel_calls=tf.data.experimental.AUTOTUNE)

0 training/coqui_stt_training/util/gpu.py Executable file → Normal file
training/coqui_stt_training/util/helpers.py
@ -19,14 +19,19 @@ ValueRange = namedtuple("ValueRange", "start end r")


def parse_file_size(file_size):
-    file_size = file_size.lower().strip()
-    if len(file_size) == 0:
-        return 0
-    n = int(keep_only_digits(file_size))
-    if file_size[-1] == "b":
-        file_size = file_size[:-1]
-    e = file_size[-1]
-    return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
    if type(file_size) is str:
        file_size = file_size.lower().strip()
        if len(file_size) == 0:
            return 0
        n = int(keep_only_digits(file_size))
        if file_size[-1] == "b":
            file_size = file_size[:-1]
        e = file_size[-1]
        return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
    elif type(file_size) is int:
        return file_size
    else:
        raise ValueError("file_size not of type 'int' or 'str'")


def keep_only_digits(txt):
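Illustrative behavior of the widened parse_file_size, assuming SIZE_PREFIX_LOOKUP maps the k/m/g prefixes to powers of 1024 as elsewhere in this module:

parse_file_size("1mb")   # -> 1048576 (trailing "b" stripped, "m" looked up)
parse_file_size("512k")  # -> 524288
parse_file_size(4096)    # -> 4096 (ints now pass through unchanged)
parse_file_size("")      # -> 0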
@ -158,35 +163,6 @@ class LimitingPool:
        self.pool.close()


-class ExceptionBox:
-    """Helper class for passing-back and re-raising an exception from inside a TensorFlow dataset generator.
-    Used in conjunction with `remember_exception`."""
-
-    def __init__(self):
-        self.exception = None
-
-    def raise_if_set(self):
-        if self.exception is not None:
-            exception = self.exception
-            self.exception = None
-            raise exception  # pylint: disable = raising-bad-type
-
-
-def remember_exception(iterable, exception_box=None):
-    """Wraps a TensorFlow dataset generator for catching its actual exceptions
-    that would otherwise just interrupt iteration w/o bubbling up."""
-
-    def do_iterate():
-        try:
-            yield from iterable()
-        except StopIteration:
-            return
-        except Exception as ex:  # pylint: disable = broad-except
-            exception_box.exception = ex
-
-    return iterable if exception_box is None else do_iterate


def get_value_range(value, target_type):
    """
    This function converts all possible supplied values for augmentation
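For reference, the removed pattern sketched end to end: exceptions raised inside a tf.data generator were parked in an ExceptionBox and re-raised from the training loop, since TensorFlow would otherwise end iteration silently. Here generate_values stands in for any dataset generator:

# Assumes the ExceptionBox/remember_exception helpers deleted above.
box = ExceptionBox()
dataset = tf.data.Dataset.from_generator(
    remember_exception(generate_values, box),  # wrap the generator
    output_types=(tf.string, tf.float32),
)
for batch in dataset:       # inside the training loop
    box.raise_if_set()      # surface any error captured in the generator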

training/coqui_stt_training/util/io.py
@ -90,3 +90,10 @@ def remove_remote(filename):
    """
    # Conditional import
    return gfile.remove(filename)


def rmtree_remote(foldername):
    """
    Wrapper that can remove local and remote directories like `gs://...`
    """
    return gfile.rmtree(foldername)

transcribe.py
@ -20,7 +20,7 @@ from multiprocessing import Process, cpu_count

from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, initialize_globals
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
from coqui_stt_training.util.feeding import split_audio_file
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import (
@ -42,7 +42,8 @@ def transcribe_file(audio_path, tlog_path):
    )
    from coqui_stt_training.util.checkpoints import load_graph_for_evaluation

-    initialize_globals()
    initialize_globals_from_cli()

    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
    try:
        num_processes = cpu_count()