Merge pull request #1921 from coqui-ai/tflite-only

Remove full TF backend from native client and CI
Reuben Morais 2021-07-29 17:44:38 +02:00 committed by GitHub
commit 58a8e813e4
32 changed files with 291 additions and 1166 deletions


@@ -37,6 +37,7 @@ async function getGoodArtifacts(client, owner, repo, releaseId, name) {
async function main() {
try {
const token = core.getInput("github_token", { required: true });
const [owner, repo] = core.getInput("repo", { required: true }).split("/");
const path = core.getInput("path", { required: true });
const name = core.getInput("name");
@@ -44,6 +45,7 @@ async function main() {
const releaseTag = core.getInput("release-tag");
const OctokitWithThrottling = GitHub.plugin(throttling);
const client = new OctokitWithThrottling({
auth: token,
throttle: {
onRateLimit: (retryAfter, options) => {
console.log(
@@ -54,6 +56,9 @@ async function main() {
if (options.request.retryCount <= 2) {
console.log(`Retrying after ${retryAfter} seconds!`);
return true;
} else {
console.log("Exhausted 2 retries");
core.setFailed("Exhausted 2 retries");
}
},
onAbuseLimit: (retryAfter, options) => {
@@ -61,6 +66,7 @@ async function main() {
console.log(
`Abuse detected for request ${options.method} ${options.url}`
);
core.setFailed(`GitHub REST API Abuse detected for request ${options.method} ${options.url}`)
},
},
});
@@ -101,6 +107,7 @@ async function main() {
await Download(artifact.url, dir, {
headers: {
"Accept": "application/octet-stream",
"Authorization": `token ${token}`,
},
});
}
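
For reference, both headers set above are what the GitHub REST API requires when fetching a release asset: Accept selects the binary payload and Authorization carries the token. A minimal curl equivalent of this download step (a hedged sketch; ASSET_URL and OUTPUT_PATH are placeholders, not part of this change):

    # Fetch a release asset through the REST API; the asset URL looks like
    # https://api.github.com/repos/OWNER/REPO/releases/assets/ID
    curl -sSL \
      -H "Accept: application/octet-stream" \
      -H "Authorization: token ${GITHUB_TOKEN}" \
      -o "${OUTPUT_PATH}" \
      "${ASSET_URL}"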


@@ -5,11 +5,8 @@ inputs:
description: "Target arch for loading script (host/armv7/aarch64)"
required: false
default: "host"
flavor:
description: "Build flavor"
required: true
runs:
using: "composite"
steps:
- run: ./ci_scripts/${{ inputs.arch }}-build.sh ${{ inputs.flavor }}
- run: ./ci_scripts/${{ inputs.arch }}-build.sh
shell: bash


@@ -1,9 +1,6 @@
name: "Python binding"
description: "Binding a python binding"
inputs:
build_flavor:
description: "Python package name"
required: true
numpy_build:
description: "NumPy build dependecy"
required: true
@@ -46,9 +43,6 @@ runs:
set -xe
PROJECT_NAME="stt"
if [ "${{ inputs.build_flavor }}" = "tflite" ]; then
PROJECT_NAME="stt-tflite"
fi
OS=$(uname)
if [ "${OS}" = "Linux" ]; then


@@ -4,9 +4,6 @@ inputs:
runtime:
description: "Runtime to use for running test"
required: true
build-flavor:
description: "Running against TF or TFLite"
required: true
model-kind:
description: "Running against CI baked or production model"
required: true
@@ -22,10 +19,7 @@ runs:
- run: |
set -xe
build=""
if [ "${{ inputs.build-flavor }}" = "tflite" ]; then
build="_tflite"
fi
model_kind=""
if [ "${{ inputs.model-kind }}" = "prod" ]; then

File diff suppressed because it is too large


@@ -55,23 +55,6 @@ maybe_install_xldd()
fi
}
# Checks whether we run a patched version of bazel.
# Patching is required to dump computeKey() parameters to .ckd files
# See bazel.patch
# Return 0 (success exit code) on patched version, 1 on release version
is_patched_bazel()
{
bazel_version=$(bazel version | grep 'Build label:' | cut -d':' -f2)
bazel shutdown
if [ -z "${bazel_version}" ]; then
return 0;
else
return 1;
fi;
}
verify_bazel_rebuild()
{
bazel_explain_file="$1"


@@ -9,21 +9,15 @@ do_bazel_build()
cd ${DS_TFDIR}
eval "export ${BAZEL_ENV_FLAGS}"
if [ "${_opt_or_dbg}" = "opt" ]; then
if is_patched_bazel; then
find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-tf.tar -T -
fi;
fi;
bazel ${BAZEL_OUTPUT_USER_ROOT} build \
-s --explain bazel_monolithic.log --verbose_explanations --experimental_strict_action_env --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
-s --explain bazel_explain.log --verbose_explanations \
--experimental_strict_action_env \
--workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
-c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
if [ "${_opt_or_dbg}" = "opt" ]; then
if is_patched_bazel; then
find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-ds.tar -T -
fi;
verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_monolithic.log"
fi;
verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_explain.log"
fi
}
shutdown_bazel()
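
The --explain/--verbose_explanations flags above make bazel write the reason for every re-executed action to bazel_explain.log, which verify_bazel_rebuild then inspects. A minimal sketch of that style of determinism check, assuming re-run actions appear in the log as "Executing action" lines (the exact wording is a bazel implementation detail):

    # Hedged sketch: fail when a supposedly no-op rebuild re-ran actions.
    verify_rebuild_sketch()
    {
      explain_log="$1"
      nb_rebuilt=$(grep -c 'Executing action' "${explain_log}" || true)
      if [ "${nb_rebuilt}" -gt 0 ]; then
        echo "ERROR: ${nb_rebuilt} actions re-executed on rebuild"
        return 1
      fi
    }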


@@ -1,26 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_material
export PATH=${CI_TMP_DIR}/ds/:$PATH
check_versions
run_prod_inference_tests "${bitrate}"


@@ -1,24 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
export PATH=${CI_TMP_DIR}/ds/:$PATH
check_versions
run_all_inference_tests
run_multi_inference_tests
run_cpp_only_inference_tests
run_hotword_tests


@@ -1,20 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_material "${CI_TMP_DIR}/ds"
export PATH=${CI_TMP_DIR}/ds/:$PATH
check_versions
ensure_cuda_usage "$2"
run_basic_inference_tests


@@ -1,48 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_data
node --version
npm --version
symlink_electron
export_node_bin_path
which electron
which node
if [ "${OS}" = "Linux" ]; then
export DISPLAY=':99.0'
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
xvfb_process=$!
fi
node --version
stt --version
check_runtime_electronjs
run_electronjs_prod_inference_tests "${bitrate}"
if [ "${OS}" = "Linux" ]; then
sleep 1
sudo kill -9 ${xvfb_process} || true
fi


@@ -1,41 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
node --version
npm --version
symlink_electron
export_node_bin_path
which electron
which node
if [ "${OS}" = "Linux" ]; then
export DISPLAY=':99.0'
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
xvfb_process=$!
fi
node --version
stt --version
check_runtime_electronjs
run_electronjs_inference_tests
if [ "${OS}" = "Linux" ]; then
sleep 1
sudo kill -9 ${xvfb_process} || true
fi


@@ -2,8 +2,6 @@
set -xe
runtime=$1
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/build-utils.sh
@@ -15,10 +13,7 @@ BAZEL_TARGETS="
//native_client:generate_scorer_package
"
if [ "${runtime}" = "tflite" ]; then
BAZEL_BUILD_TFLITE="--define=runtime=tflite"
fi;
BAZEL_BUILD_FLAGS="${BAZEL_BUILD_TFLITE} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
BAZEL_BUILD_FLAGS="${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
BAZEL_ENV_FLAGS="TF_NEED_CUDA=0"
SYSTEM_TARGET=host


@@ -1,30 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_data
node --version
npm --version
export_node_bin_path
check_runtime_nodejs
run_prod_inference_tests "${bitrate}"
run_js_streaming_prod_inference_tests "${bitrate}"


@@ -1,25 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
node --version
npm --version
export_node_bin_path
check_runtime_nodejs
run_all_inference_tests
run_js_streaming_inference_tests
run_hotword_tests


@@ -30,10 +30,15 @@ package_native_client()
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so.if.lib"
fi;
if [ -f "${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib" ]; then
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so.if.lib"
fi;
${TAR} --verbose -cf - \
--transform='flags=r;s|README.coqui|KenLM_License_Info.txt|' \
-C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so \
-C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so \
-C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so \
${win_lib} \
-C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
-C ${stt_dir}/ LICENSE \
@@ -94,5 +99,8 @@ package_libstt_as_zip()
echo "Please specify artifact name."
fi;
${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" ${tensorflow_dir}/bazel-bin/native_client/libstt.so
${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" \
${tensorflow_dir}/bazel-bin/native_client/libstt.so \
${tensorflow_dir}/bazel-bin/native_client/libkenlm.so \
${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so
}
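
Since the archive now has to carry three shared objects, a quick sanity check after packaging (a hedged one-liner reusing the script's own variables) is:

    # Confirm libstt, libkenlm and libtensorflowlite all landed in the zip.
    unzip -l "${artifacts_dir}/${artifact_name}" | grep -E 'libstt|libkenlm|libtensorflowlite'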


@@ -1,29 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_material
export_py_bin_path
which stt
stt --version
run_prod_inference_tests "${bitrate}"
run_prod_concurrent_stream_tests "${bitrate}"


@@ -1,21 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
export_py_bin_path
which stt
stt --version
run_all_inference_tests
run_hotword_tests


@@ -6,7 +6,7 @@ set -o pipefail
source $(dirname $0)/tf-vars.sh
pushd ${DS_ROOT_TASK}/tensorflow/
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s --explain bazel_monolithic_tf.log --verbose_explanations --experimental_strict_action_env --config=monolithic"
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s"
# Start a bazel process to ensure reliability on Windows and avoid:
# FATAL: corrupt installation: file 'c:\builds\tc-workdir\.bazel_cache/install/6b1660721930e9d5f231f7d2a626209b/_embedded_binaries/build-runfiles.exe' missing.
@@ -23,13 +23,10 @@ pushd ${DS_ROOT_TASK}/tensorflow/
case "$1" in
"--windows-cpu")
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIBSTT} ${BUILD_TARGET_LITE_LIB} --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh"
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
;;
"--linux-cpu"|"--darwin-cpu")
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_LITE_LIB}
;;
"--linux-cuda"|"--windows-cuda")
eval "export ${TF_CUDA_FLAGS}" && (echo "" | TF_NEED_CUDA=1 ./configure) && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_CUDA_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API}
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
;;
"--linux-armv7")
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}


@@ -6,26 +6,17 @@ source $(dirname $0)/tf-vars.sh
mkdir -p ${CI_ARTIFACTS_DIR} || true
cp ${DS_ROOT_TASK}/tensorflow/bazel_*.log ${CI_ARTIFACTS_DIR} || true
OUTPUT_ROOT="${DS_ROOT_TASK}/tensorflow/bazel-bin"
for output_bin in \
tensorflow/lite/experimental/c/libtensorflowlite_c.so \
tensorflow/tools/graph_transforms/transform_graph \
tensorflow/tools/graph_transforms/summarize_graph \
tensorflow/tools/benchmark/benchmark_model \
tensorflow/contrib/util/convert_graphdef_memmapped_format \
tensorflow/lite/toco/toco;
tensorflow/lite/libtensorflowlite.so \
tensorflow/lite/libtensorflowlite.so.if.lib \
;
do
if [ -f "${OUTPUT_ROOT}/${output_bin}" ]; then
cp ${OUTPUT_ROOT}/${output_bin} ${CI_ARTIFACTS_DIR}/
fi;
done;
if [ -f "${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model" ]; then
cp ${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model ${CI_ARTIFACTS_DIR}/lite_benchmark_model
fi
done
# It seems that bsdtar and gnutar behave a bit differently in the way
# they deal with --exclude="./public/*" ; this caused ./STT/tensorflow/core/public/


@@ -5,12 +5,7 @@ set -ex
source $(dirname $0)/tf-vars.sh
install_android=
install_cuda=
case "$1" in
"--linux-cuda"|"--windows-cuda")
install_cuda=yes
;;
"--android-armv7"|"--android-arm64")
install_android=yes
;;
@@ -29,11 +24,6 @@ download()
mkdir -p ${DS_ROOT_TASK}/dls || true
download $BAZEL_URL $BAZEL_SHA256
if [ ! -z "${install_cuda}" ]; then
download $CUDA_URL $CUDA_SHA256
download $CUDNN_URL $CUDNN_SHA256
fi;
if [ ! -z "${install_android}" ]; then
download $ANDROID_NDK_URL $ANDROID_NDK_SHA256
download $ANDROID_SDK_URL $ANDROID_SDK_SHA256
@@ -63,30 +53,6 @@ bazel version
bazel shutdown
if [ ! -z "${install_cuda}" ]; then
# Install CUDA and CuDNN
mkdir -p ${DS_ROOT_TASK}/STT/CUDA/ || true
pushd ${DS_ROOT_TASK}
CUDA_FILE=`basename ${CUDA_URL}`
PERL5LIB=. sh ${DS_ROOT_TASK}/dls/${CUDA_FILE} --silent --override --toolkit --toolkitpath=${DS_ROOT_TASK}/STT/CUDA/ --defaultroot=${DS_ROOT_TASK}/STT/CUDA/
CUDNN_FILE=`basename ${CUDNN_URL}`
tar xvf ${DS_ROOT_TASK}/dls/${CUDNN_FILE} --strip-components=1 -C ${DS_ROOT_TASK}/STT/CUDA/
popd
LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH
# We might lack libcuda.so.1 symlink, let's fix as upstream does:
# https://github.com/tensorflow/tensorflow/pull/13811/files?diff=split#diff-2352449eb75e66016e97a591d3f0f43dR96
if [ ! -h "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1" ]; then
ln -s "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so" "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1"
fi;
else
echo "No CUDA/CuDNN to install"
fi
if [ ! -z "${install_android}" ]; then
mkdir -p ${DS_ROOT_TASK}/STT/Android/SDK || true
ANDROID_NDK_FILE=`basename ${ANDROID_NDK_URL}`


@@ -9,13 +9,6 @@ if [ "${OS}" = "Linux" ]; then
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh
BAZEL_SHA256=7ba815cbac712d061fe728fef958651512ff394b2708e89f79586ec93d1185ed
CUDA_URL=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
CUDA_SHA256=e7c22dc21278eb1b82f34a60ad7640b41ad3943d929bebda3008b72536855d31
# From https://gitlab.com/nvidia/cuda/blob/centos7/10.1/devel/cudnn7/Dockerfile
CUDNN_URL=http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.1-linux-x64-v7.6.0.64.tgz
CUDNN_SHA256=e956c6f9222fcb867a10449cfc76dee5cfd7c7531021d95fe9586d7e043b57d7
ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r18b-linux-x86_64.zip
ANDROID_NDK_SHA256=4f61cbe4bbf6406aa5ef2ae871def78010eed6271af72de83f8bd0b07a9fd3fd
@@ -48,8 +41,6 @@ elif [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-windows-x86_64.exe
BAZEL_SHA256=776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9
CUDA_INSTALL_DIRECTORY=$(cygpath 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1')
TAR=/usr/bin/tar.exe
elif [ "${OS}" = "Darwin" ]; then
if [ -z "${CI_TASK_DIR}" -o -z "${CI_ARTIFACTS_DIR}" ]; then
@@ -89,7 +80,6 @@ fi;
export PATH
if [ "${OS}" = "Linux" ]; then
export LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
export ANDROID_SDK_HOME=${DS_ROOT_TASK}/STT/Android/SDK/
export ANDROID_NDK_HOME=${DS_ROOT_TASK}/STT/Android/android-ndk-r18b/
fi;
@@ -160,27 +150,15 @@ export BAZEL_OUTPUT_USER_ROOT
NVCC_COMPUTE="3.5"
### Define build parameters/env variables that we will re-use in sourcing scripts.
if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_PATHS=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
else
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_PATHS=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
fi
BAZEL_ARM_FLAGS="--config=rpi3 --config=rpi3_opt --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_ARM64_FLAGS="--config=rpi3-armv8 --config=rpi3-armv8_opt --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_ANDROID_ARM_FLAGS="--config=android --config=android_arm --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_ANDROID_ARM64_FLAGS="--config=android --config=android_arm64 --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_CUDA_FLAGS="--config=cuda"
if [ "${OS}" = "Linux" ]; then
# constexpr usage in tensorflow's absl dep fails badly because of gcc-5
# so let's skip that
BAZEL_CUDA_FLAGS="${BAZEL_CUDA_FLAGS} --copt=-DNO_CONSTEXPR_FOR_YOU=1"
fi
BAZEL_IOS_ARM64_FLAGS="--config=ios_arm64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
BAZEL_IOS_X86_64_FLAGS="--config=ios_x86_64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
if [ "${OS}" != "${CI_MSYS_VERSION}" ]; then
BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl --copt=-fvisibility=hidden"
BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl"
fi
if [ "${OS}" = "Darwin" ]; then
@@ -189,11 +167,5 @@ fi
### Define build targets that we will re-use in sourcing scripts.
BUILD_TARGET_LIB_CPP_API="//tensorflow:tensorflow_cc"
BUILD_TARGET_GRAPH_TRANSFORMS="//tensorflow/tools/graph_transforms:transform_graph"
BUILD_TARGET_GRAPH_SUMMARIZE="//tensorflow/tools/graph_transforms:summarize_graph"
BUILD_TARGET_GRAPH_BENCHMARK="//tensorflow/tools/benchmark:benchmark_model"
#BUILD_TARGET_CONVERT_MMAP="//tensorflow/contrib/util:convert_graphdef_memmapped_format"
BUILD_TARGET_TOCO="//tensorflow/lite/toco:toco"
BUILD_TARGET_LITE_BENCHMARK="//tensorflow/lite/tools/benchmark:benchmark_model"
BUILD_TARGET_LITE_LIB="//tensorflow/lite/c:libtensorflowlite_c.so"
BUILD_TARGET_LITE_LIB="//tensorflow/lite:libtensorflowlite.so"
BUILD_TARGET_LIBSTT="//native_client:libstt.so"


@@ -7,65 +7,46 @@ Here we maintain the list of supported platforms for deployment.
*Note that 🐸STT currently only provides packages for CPU deployment with Python 3.5 or higher on Linux. We're working to get the rest of our usually supported packages back up and running as soon as possible.*
Linux / AMD64 without GPU
Linux / AMD64
^^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Ubuntu 14.04+ (glibc >= 2.19, libstdc++6 >= 4.8)
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
Linux / AMD64 with GPU
^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Ubuntu 14.04+ (glibc >= 2.19, libstdc++6 >= 4.8)
* CUDA 10.0 (and capable GPU)
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
* glibc >= 2.24, libstdc++6 >= 6.3
* TensorFlow Lite runtime
Linux / ARMv7
^^^^^^^^^^^^^
* Cortex-A53 compatible ARMv7 SoC with Neon support
* Raspbian Buster-compatible distribution
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
Linux / Aarch64
^^^^^^^^^^^^^^^
* Cortex-A72 compatible Aarch64 SoC
* ARMbian Buster-compatible distribution
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
Android / ARMv7
^^^^^^^^^^^^^^^
* ARMv7 SoC with Neon support
* Android 7.0-10.0
* NDK API level >= 21
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
Android / Aarch64
^^^^^^^^^^^^^^^^^
* Aarch64 SoC
* Android 7.0-10.0
* NDK API level >= 21
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
macOS / AMD64
^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* macOS >= 10.10
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
Windows / AMD64 without GPU
Windows / AMD64
^^^^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Windows Server >= 2012 R2 ; Windows >= 8.1
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
Windows / AMD64 with GPU
^^^^^^^^^^^^^^^^^^^^^^^^
* x86-64 CPU with AVX/FMA (one can rebuild without AVX/FMA, but it might slow down performance)
* Windows Server >= 2012 R2 ; Windows >= 8.1
* CUDA 10.0 (and capable GPU)
* Full TensorFlow runtime (``stt`` packages)
* TensorFlow Lite runtime (``stt-tflite`` packages)
* TensorFlow Lite runtime
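
Every platform above now ships a single TensorFlow Lite based package. For orientation, a typical install-and-infer session with the Python package looks like this (a hedged example; the model, scorer and audio file names are placeholders):

    # Install the Python package and run inference against a TFLite model.
    python3 -m pip install stt
    stt --model model.tflite --scorer kenlm.scorer --audio audio.wav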


@@ -1,22 +1,9 @@
# Description: Coqui STT native client library.
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object", "tf_copts", "lrt_if_needed")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "lrt_if_needed")
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
load(
"@org_tensorflow//tensorflow/lite:build_def.bzl",
"tflite_copts",
"tflite_linkopts",
)
config_setting(
name = "tflite",
define_values = {
"runtime": "tflite",
},
)
config_setting(
name = "rpi3",
@@ -85,7 +72,7 @@ LINUX_LINKOPTS = [
"-Wl,-export-dynamic",
]
tf_cc_shared_object(
cc_binary(
name = "libkenlm.so",
srcs = glob([
"kenlm/lm/*.hh",
@@ -107,8 +94,19 @@ tf_cc_shared_object(
}),
defines = ["KENLM_MAX_ORDER=6"],
includes = ["kenlm"],
framework_so = [],
linkopts = [],
linkshared = 1,
linkopts = select({
"//tensorflow:ios": [
"-Wl,-install_name,@rpath/libkenlm.so",
],
"//tensorflow:macos": [
"-Wl,-install_name,@rpath/libkenlm.so",
],
"//tensorflow:windows": [],
"//conditions:default": [
"-Wl,-soname,libkenlm.so",
],
}),
)
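
One way to confirm the per-platform linkopts above took effect is to inspect the built library's dynamic section (a hedged check, Linux shown; the bazel-bin path assumes a default output layout):

    # The ELF should carry the soname set via -Wl,-soname above.
    readelf -d bazel-bin/native_client/libkenlm.so | grep SONAME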
cc_library(
@@ -132,6 +130,20 @@ cc_library(
copts = ["-fexceptions"],
)
cc_library(
name="tflite",
hdrs = [
"//tensorflow/lite:model.h",
"//tensorflow/lite/kernels:register.h",
"//tensorflow/lite/tools/evaluation:utils.h",
],
srcs = [
"//tensorflow/lite:libtensorflowlite.so",
],
includes = ["tensorflow"],
deps = ["//tensorflow/lite:libtensorflowlite.so"],
)
cc_library(
name = "coqui_stt_bundle",
srcs = [
@@ -142,17 +154,10 @@ cc_library(
"modelstate.h",
"workspace_status.cc",
"workspace_status.h",
] + select({
"//native_client:tflite": [
"tflitemodelstate.h",
"tflitemodelstate.cc",
],
"//conditions:default": [
"tfmodelstate.h",
"tfmodelstate.cc",
],
}) + DECODER_SOURCES,
copts = tf_copts(allow_exceptions=True) + select({
] + DECODER_SOURCES,
copts = select({
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
"//tensorflow:windows": ["/w"],
# -Wno-sign-compare to silence a lot of warnings from tensorflow itself,
@@ -161,9 +166,6 @@
"-Wno-sign-compare",
"-fvisibility=hidden",
],
}) + select({
"//native_client:tflite": ["-DUSE_TFLITE"],
"//conditions:default": ["-UUSE_TFLITE"],
}),
linkopts = lrt_if_needed() + select({
"//tensorflow:macos": [],
@@ -174,64 +176,32 @@ cc_library(
# Bazel has too strong opinions about static linking, so it's
# near impossible to get it to link a DLL against another DLL on Windows.
# We simply force the linker option manually here as a hacky fix.
"//tensorflow:windows": ["bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib"],
"//tensorflow:windows": [
"bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib",
"bazel-out/x64_windows-opt/bin/tensorflow/lite/libtensorflowlite.so.if.lib",
],
"//conditions:default": [],
}) + tflite_linkopts() + DECODER_LINKOPTS,
}) + DECODER_LINKOPTS,
includes = DECODER_INCLUDES,
deps = select({
"//native_client:tflite": [
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/tools/evaluation:utils",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//third_party/eigen3",
#"//tensorflow/core:all_kernels",
### => Trying to be more fine-grained
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
### CPU only build, libstt.so file size reduced by ~50%
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
"//tensorflow/core/kernels:bias_op", # BiasAdd
"//tensorflow/core/kernels:cast_op", # Cast
"//tensorflow/core/kernels:concat_op", # ConcatV2
"//tensorflow/core/kernels:constant_op", # Const, Placeholder
"//tensorflow/core/kernels:shape_ops", # ExpandDims, Shape
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
"//tensorflow/core/kernels:identity_op", # Identity
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
"//tensorflow/core/kernels:deepspeech_cwise_ops", # Less, Minimum, Mul
"//tensorflow/core/kernels:matmul_op", # MatMul
"//tensorflow/core/kernels:reduction_ops", # Max
"//tensorflow/core/kernels:mfcc_op", # Mfcc
"//tensorflow/core/kernels:no_op", # NoOp
"//tensorflow/core/kernels:pack_op", # Pack
"//tensorflow/core/kernels:sequence_ops", # Range
"//tensorflow/core/kernels:relu_op", # Relu
"//tensorflow/core/kernels:reshape_op", # Reshape
"//tensorflow/core/kernels:softmax_op", # Softmax
"//tensorflow/core/kernels:tile_ops", # Tile
"//tensorflow/core/kernels:transpose_op", # Transpose
"//tensorflow/core/kernels:rnn_ops", # BlockLSTM
# And we also need the op libs for these ops used in the model:
"//tensorflow/core:audio_ops_op_lib", # AudioSpectrogram, Mfcc
"//tensorflow/core:rnn_ops_op_lib", # BlockLSTM
"//tensorflow/core:math_ops_op_lib", # Cast, Less, Max, MatMul, Minimum, Range
"//tensorflow/core:array_ops_op_lib", # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
"//tensorflow/core:no_op_op_lib", # NoOp
"//tensorflow/core:nn_ops_op_lib", # Relu, Softmax, BiasAdd
# And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
"//tensorflow/core:dataset_ops_op_lib", # UnwrapDatasetVariant, WrapDatasetVariant
"//tensorflow/core:sendrecv_ops_op_lib", # _HostRecv, _HostSend, _Recv, _Send
],
}) + if_cuda([
"//tensorflow/core:core",
]) + [":kenlm"],
deps = [":kenlm", ":tflite"],
)
tf_cc_shared_object(
cc_binary(
name = "libstt.so",
deps = [":coqui_stt_bundle"],
linkshared = 1,
linkopts = select({
"//tensorflow:ios": [
"-Wl,-install_name,@rpath/libstt.so",
],
"//tensorflow:macos": [
"-Wl,-install_name,@rpath/libstt.so",
],
"//tensorflow:windows": [],
"//conditions:default": [
"-Wl,-soname,libstt.so",
],
}),
)
ios_static_framework(
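
With the tflite config_setting and the TF-backend select() branches gone, one unconditional invocation builds the TFLite-only artifacts. A sketch mirroring ci_scripts/build-utils.sh above (flags abbreviated, not the exact CI command line):

    # Build the native client against the bundled TFLite runtime only.
    bazel build \
      --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
      -c opt \
      //native_client:libstt.so \
      //native_client:generate_scorer_package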


@@ -45,14 +45,14 @@ workspace_status.cc:
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings: clean-keep-third-party workspace_status.cc $(DS_SWIG_DEP)
python -m pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==45.0.0
python -m pip install --quiet $(PYTHON_PACKAGES) wheel setuptools
DISTUTILS_USE_SDK=1 PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN_DIR):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
find temp_build -type f -name "*.o" -delete
DISTUTILS_USE_SDK=1 AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
rm -rf temp_build
bindings-debug: clean-keep-third-party workspace_status.cc $(DS_SWIG_DEP)
python -m pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==45.0.0
python -m pip install --quiet $(PYTHON_PACKAGES) wheel setuptools
DISTUTILS_USE_SDK=1 PATH=$(DS_SWIG_BIN_PATH):$(TOOLCHAIN_DIR):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) LIBEXE=$(LIBEXE) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
$(GENERATE_DEBUG_SYMS)
find temp_build -type f -name "*.o" -delete


@@ -20,8 +20,8 @@ endif
STT_BIN := stt$(PLATFORM_EXE_SUFFIX)
CFLAGS_STT := -std=c++11 -o $(STT_BIN)
LINK_STT := -lstt -lkenlm
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client
LINK_STT := -lstt -lkenlm -ltensorflowlite
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client -L${TFDIR}/bazel-bin/tensorflow/lite
ifeq ($(TARGET),host)
TOOLCHAIN :=
@@ -50,9 +50,6 @@ else
SOX_LDFLAGS := `pkg-config --libs sox`
endif # OS others
PYTHON_PACKAGES := numpy${NUMPY_BUILD_VERSION}
ifeq ($(OS),Linux)
PYTHON_PLATFORM_NAME ?= --plat-name manylinux1_x86_64
endif
endif
ifeq ($(findstring _NT,$(OS)),_NT)
@@ -61,7 +58,7 @@ TOOL_CC := cl.exe
TOOL_CXX := cl.exe
TOOL_LD := link.exe
TOOL_LIBEXE := lib.exe
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib")
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib")
LINK_PATH_STT :=
CFLAGS_STT := -nologo -Fe$(STT_BIN)
SOX_CFLAGS :=
@@ -185,7 +182,7 @@ define copy_missing_libs
new_missing="$$( (for f in $$(otool -L $$lib 2>/dev/null | tail -n +2 | awk '{ print $$1 }' | grep -v '$$lib'); do ls -hal $$f; done;) 2>&1 | grep 'No such' | cut -d':' -f2 | xargs basename -a)"; \
missing_libs="$$missing_libs $$new_missing"; \
elif [ "$(OS)" = "${CI_MSYS_VERSION}" ]; then \
missing_libs="libstt.so libkenlm.so"; \
missing_libs="libstt.so libkenlm.so libtensorflowlite.so"; \
else \
missing_libs="$$missing_libs $$($(LDD) $$lib | grep 'not found' | awk '{ print $$1 }')"; \
fi; \
@@ -237,7 +234,7 @@ DS_SWIG_ENV := SWIG_LIB="$(SWIG_LIB)" PATH="$(DS_SWIG_BIN_PATH):${PATH}"
$(DS_SWIG_BIN_PATH)/swig:
mkdir -p $(SWIG_ROOT)
wget -O - "$(SWIG_DIST_URL)" | tar -C $(SWIG_ROOT) -zxf -
curl -sSL "$(SWIG_DIST_URL)" | tar -C $(SWIG_ROOT) -zxf -
ln -s $(DS_SWIG_BIN) $(DS_SWIG_BIN_PATH)/$(SWIG_BIN)
ds-swig: $(DS_SWIG_BIN_PATH)/swig


@@ -27,12 +27,14 @@
"libraries": [
"../../../tensorflow/bazel-bin/native_client/libstt.so.if.lib",
"../../../tensorflow/bazel-bin/native_client/libkenlm.so.if.lib",
"../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib",
],
},
{
"libraries": [
"../../../tensorflow/bazel-bin/native_client/libstt.so",
"../../../tensorflow/bazel-bin/native_client/libkenlm.so",
"../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so",
],
},
],
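
Because the Node addon now links against three shared objects, a quick post-build check is to make sure the dynamic loader resolves them all (a hedged check on Linux; the addon path is hypothetical and depends on the build layout):

    # All three libraries should resolve without "not found".
    ldd build/Release/stt.node | grep -E 'libstt|libkenlm|libtensorflowlite'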


@@ -9,7 +9,7 @@ bindings-clean:
# Enforce PATH here because swig calls from build_ext lose track of some
# variables over several runs
bindings-build: ds-swig
pip3 install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==45.0.0
pip3 install --quiet $(PYTHON_PACKAGES) wheel setuptools
DISTUTILS_USE_SDK=1 PATH=$(TOOLCHAIN_DIR):$(DS_SWIG_BIN_PATH):$$PATH SWIG_LIB="$(SWIG_LIB)" AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED) $(RPATH_PYTHON)" MODEL_LDFLAGS="$(LDFLAGS_DIRS)" MODEL_LIBS="$(LIBS)" $(PYTHON_PATH) $(PYTHON_SYSCONFIGDATA) $(NUMPY_INCLUDE) python3 ./setup.py build_ext $(PYTHON_PLATFORM_NAME)
MANIFEST.in: bindings-build


@@ -14,13 +14,7 @@
#include "modelstate.h"
#include "workspace_status.h"
#ifndef USE_TFLITE
#include "tfmodelstate.h"
#else
#include "tflitemodelstate.h"
#endif // USE_TFLITE
#include "ctcdecode/ctc_beam_search_decoder.h"
#ifdef __ANDROID__
@@ -282,13 +276,7 @@ STT_CreateModel(const char* aModelPath,
return STT_ERR_NO_MODEL;
}
std::unique_ptr<ModelState> model(
#ifndef USE_TFLITE
new TFModelState()
#else
new TFLiteModelState()
#endif
);
std::unique_ptr<ModelState> model(new TFLiteModelState());
if (!model) {
std::cerr << "Could not allocate model state." << std::endl;


@@ -1,263 +0,0 @@
#include "tfmodelstate.h"
#include "workspace_status.h"
using namespace tensorflow;
using std::vector;
TFModelState::TFModelState()
: ModelState()
, mmap_env_(nullptr)
, session_(nullptr)
{
}
TFModelState::~TFModelState()
{
if (session_) {
Status status = session_->Close();
if (!status.ok()) {
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
}
}
}
int
TFModelState::init(const char* model_path)
{
int err = ModelState::init(model_path);
if (err != STT_ERR_OK) {
return err;
}
Status status;
SessionOptions options;
mmap_env_.reset(new MemmappedEnv(Env::Default()));
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
if (!is_mmap) {
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
} else {
status = mmap_env_->InitializeFromFile(model_path);
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_INIT_MMAP;
}
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(::OptimizerOptions::L0);
options.env = mmap_env_.get();
}
Session* session;
status = NewSession(options, &session);
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_INIT_SESS;
}
session_.reset(session);
if (is_mmap) {
status = ReadBinaryProto(mmap_env_.get(),
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&graph_def_);
} else {
status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
}
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_READ_PROTOBUF;
}
status = session_->Create(graph_def_);
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_CREATE_SESS;
}
std::vector<tensorflow::Tensor> version_output;
status = session_->Run({}, {
"metadata_version"
}, {}, &version_output);
if (!status.ok()) {
std::cerr << "Unable to fetch graph version: " << status << std::endl;
return STT_ERR_MODEL_INCOMPATIBLE;
}
int graph_version = version_output[0].scalar<int>()();
if (graph_version < ds_graph_version()) {
std::cerr << "Specified model file version (" << graph_version << ") is "
<< "incompatible with minimum version supported by this client ("
<< ds_graph_version() << "). See "
<< "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
<< "for more information" << std::endl;
return STT_ERR_MODEL_INCOMPATIBLE;
}
std::vector<tensorflow::Tensor> metadata_outputs;
status = session_->Run({}, {
"metadata_sample_rate",
"metadata_feature_win_len",
"metadata_feature_win_step",
"metadata_beam_width",
"metadata_alphabet",
}, {}, &metadata_outputs);
if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl;
return STT_ERR_MODEL_INCOMPATIBLE;
}
sample_rate_ = metadata_outputs[0].scalar<int>()();
int win_len_ms = metadata_outputs[1].scalar<int>()();
int win_step_ms = metadata_outputs[2].scalar<int>()();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
int beam_width = metadata_outputs[3].scalar<int>()();
beam_width_ = (unsigned int)(beam_width);
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
if (err != 0) {
return STT_ERR_INVALID_ALPHABET;
}
assert(sample_rate_ > 0);
assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0);
assert(beam_width_ > 0);
assert(alphabet_.GetSize() > 0);
for (int i = 0; i < graph_def_.node_size(); ++i) {
NodeDef node = graph_def_.node(i);
if (node.name() == "input_node") {
const auto& shape = node.attr().at("shape").shape();
n_steps_ = shape.dim(1).size();
n_context_ = (shape.dim(2).size()-1)/2;
n_features_ = shape.dim(3).size();
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
} else if (node.name() == "previous_state_c") {
const auto& shape = node.attr().at("shape").shape();
state_size_ = shape.dim(1).size();
} else if (node.name() == "logits_shape") {
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
continue;
}
int final_dim_size = logits_shape.vec<int>()(2) - 1;
if (final_dim_size != alphabet_.GetSize()) {
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
<< "has size " << alphabet_.GetSize()
<< ", but model has " << final_dim_size
<< " classes in its output. Make sure you're passing an alphabet "
<< "file with the same size as the one used for training."
<< std::endl;
return STT_ERR_INVALID_ALPHABET;
}
}
}
if (n_context_ == -1 || n_features_ == -1) {
std::cerr << "Error: Could not infer input shape from model file. "
<< "Make sure input_node is a 4D tensor with shape "
<< "[batch_size=1, time, window_size, n_features]."
<< std::endl;
return STT_ERR_INVALID_SHAPE;
}
return STT_ERR_OK;
}
Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
Tensor ret(DT_FLOAT, shape);
auto ret_mapped = ret.flat<float>();
int i;
for (i = 0; i < vec.size(); ++i) {
ret_mapped(i) = vec[i];
}
for (; i < shape.num_elements(); ++i) {
ret_mapped(i) = 0.f;
}
return ret;
}
void
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
auto tensor_mapped = tensor.flat<float>();
if (num_elements == -1) {
num_elements = tensor.shape().num_elements();
}
for (int i = 0; i < num_elements; ++i) {
vec.push_back(tensor_mapped(i));
}
}
void
TFModelState::infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
vector<float>& logits_output,
vector<float>& state_c_output,
vector<float>& state_h_output)
{
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor input_lengths(DT_INT32, TensorShape({1}));
input_lengths.scalar<int>()() = n_frames;
vector<Tensor> outputs;
Status status = session_->Run(
{
{"input_node", input},
{"input_lengths", input_lengths},
{"previous_state_c", previous_state_c_t},
{"previous_state_h", previous_state_h_t}
},
{"logits", "new_state_c", "new_state_h"},
{},
&outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
state_c_output.clear();
state_c_output.reserve(state_size_);
copy_tensor_to_vector(outputs[1], state_c_output);
state_h_output.clear();
state_h_output.reserve(state_size_);
copy_tensor_to_vector(outputs[2], state_h_output);
}
void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
vector<Tensor> outputs;
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
// The feature computation graph is hardcoded to one audio length for now
const int n_windows = 1;
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
copy_tensor_to_vector(outputs[0], mfcc_output);
}


@@ -1,35 +0,0 @@
#ifndef TFMODELSTATE_H
#define TFMODELSTATE_H
#include <vector>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"
#include "modelstate.h"
struct TFModelState : public ModelState
{
std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
std::unique_ptr<tensorflow::Session> session_;
tensorflow::GraphDef graph_def_;
TFModelState();
virtual ~TFModelState();
virtual int init(const char* model_path) override;
virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
std::vector<float>& logits_output,
std::vector<float>& state_c_output,
std::vector<float>& state_h_output) override;
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
std::vector<float>& mfcc_output) override;
};
#endif // TFMODELSTATE_H