Remove full TF backend

Reuben Morais 2021-07-27 21:27:40 +02:00
parent 2020f1b15a
commit 42ebbf9120
27 changed files with 247 additions and 1116 deletions

View File

@@ -5,11 +5,8 @@ inputs:
     description: "Target arch for loading script (host/armv7/aarch64)"
     required: false
     default: "host"
-  flavor:
-    description: "Build flavor"
-    required: true
 runs:
   using: "composite"
   steps:
-    - run: ./ci_scripts/${{ inputs.arch }}-build.sh ${{ inputs.flavor }}
+    - run: ./ci_scripts/${{ inputs.arch }}-build.sh
       shell: bash
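
With the flavor input gone, each arch build script is invoked without a flavor argument. A minimal before/after sketch for the default host arch (paths as in the action above):

    # before: a flavor argument (e.g. "tflite") selected the runtime at build time
    ./ci_scripts/host-build.sh tflite
    # after: TFLite is the only runtime, so the script takes no argument
    ./ci_scripts/host-build.sh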

View File

@ -1,9 +1,6 @@
name: "Python binding" name: "Python binding"
description: "Binding a python binding" description: "Binding a python binding"
inputs: inputs:
build_flavor:
description: "Python package name"
required: true
numpy_build: numpy_build:
description: "NumPy build dependecy" description: "NumPy build dependecy"
required: true required: true
@ -46,9 +43,6 @@ runs:
set -xe set -xe
PROJECT_NAME="stt" PROJECT_NAME="stt"
if [ "${{ inputs.build_flavor }}" = "tflite" ]; then
PROJECT_NAME="stt-tflite"
fi
OS=$(uname) OS=$(uname)
if [ "${OS}" = "Linux" ]; then if [ "${OS}" = "Linux" ]; then

View File

@@ -4,9 +4,6 @@ inputs:
   runtime:
     description: "Runtime to use for running test"
     required: true
-  build-flavor:
-    description: "Running against TF or TFLite"
-    required: true
   model-kind:
     description: "Running against CI baked or production model"
     required: true
@@ -22,10 +19,7 @@ runs:
     - run: |
         set -xe

-        build=""
-        if [ "${{ inputs.build-flavor }}" = "tflite" ]; then
-          build="_tflite"
-        fi
+        build="_tflite"

         model_kind=""
         if [ "${{ inputs.model-kind }}" = "prod" ]; then

File diff suppressed because it is too large

View File

@@ -55,23 +55,6 @@ maybe_install_xldd()
   fi
 }

-# Checks whether we run a patched version of bazel.
-# Patching is required to dump computeKey() parameters to .ckd files
-# See bazel.patch
-# Return 0 (success exit code) on patched version, 1 on release version
-is_patched_bazel()
-{
-  bazel_version=$(bazel version | grep 'Build label:' | cut -d':' -f2)
-
-  bazel shutdown
-
-  if [ -z "${bazel_version}" ]; then
-    return 0;
-  else
-    return 1;
-  fi;
-}
-
 verify_bazel_rebuild()
 {
   bazel_explain_file="$1"
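
The deleted is_patched_bazel helper distinguished a locally patched Bazel (used to dump computeKey() parameters to .ckd files) from a release binary, and nothing consumes that signal once the .ckd dump paths are removed from the build function in the next file. For reference, the check it keyed on (sketch):

    bazel version | grep 'Build label:'
    # a release binary prints e.g. "Build label: 3.1.0"; a locally built,
    # patched bazel leaves the label empty, which the helper treated as "patched"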

View File

@@ -9,21 +9,15 @@ do_bazel_build()
   cd ${DS_TFDIR}
   eval "export ${BAZEL_ENV_FLAGS}"

-  if [ "${_opt_or_dbg}" = "opt" ]; then
-    if is_patched_bazel; then
-      find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-tf.tar -T -
-    fi;
-  fi;
-
   bazel ${BAZEL_OUTPUT_USER_ROOT} build \
-    -s --explain bazel_monolithic.log --verbose_explanations --experimental_strict_action_env --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
+    -s --explain bazel_explain.log --verbose_explanations \
+    --experimental_strict_action_env \
+    --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
+    -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}

   if [ "${_opt_or_dbg}" = "opt" ]; then
-    if is_patched_bazel; then
-      find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-ds.tar -T -
-    fi;
-    verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_monolithic.log"
-  fi;
+    verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_explain.log"
+  fi
 }

 shutdown_bazel()
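
do_bazel_build no longer branches on a patched Bazel and drops --config=monolithic, which went away with the monolithic TF build. Expanded for an opt host build, the invocation now comes down to roughly this (sketch; target list taken from the host build script later in this commit, BAZEL_BUILD_FLAGS elided):

    bazel build \
      -s --explain bazel_explain.log --verbose_explanations \
      --experimental_strict_action_env \
      --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
      -c opt //native_client:libstt.so //native_client:generate_scorer_package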
shutdown_bazel() shutdown_bazel()

View File

@@ -1,26 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_material
export PATH=${CI_TMP_DIR}/ds/:$PATH
check_versions
run_prod_inference_tests "${bitrate}"

View File

@@ -1,24 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
export PATH=${CI_TMP_DIR}/ds/:$PATH
check_versions
run_all_inference_tests
run_multi_inference_tests
run_cpp_only_inference_tests
run_hotword_tests

View File

@@ -1,20 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_material "${CI_TMP_DIR}/ds"
export PATH=${CI_TMP_DIR}/ds/:$PATH
check_versions
ensure_cuda_usage "$2"
run_basic_inference_tests

View File

@@ -1,48 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_data
node --version
npm --version
symlink_electron
export_node_bin_path
which electron
which node
if [ "${OS}" = "Linux" ]; then
export DISPLAY=':99.0'
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
xvfb_process=$!
fi
node --version
stt --version
check_runtime_electronjs
run_electronjs_prod_inference_tests "${bitrate}"
if [ "${OS}" = "Linux" ]; then
sleep 1
sudo kill -9 ${xvfb_process} || true
fi

View File

@@ -1,41 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
node --version
npm --version
symlink_electron
export_node_bin_path
which electron
which node
if [ "${OS}" = "Linux" ]; then
export DISPLAY=':99.0'
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
xvfb_process=$!
fi
node --version
stt --version
check_runtime_electronjs
run_electronjs_inference_tests
if [ "${OS}" = "Linux" ]; then
sleep 1
sudo kill -9 ${xvfb_process} || true
fi

View File

@@ -2,8 +2,6 @@
 set -xe

-runtime=$1
-
 source $(dirname "$0")/all-vars.sh
 source $(dirname "$0")/all-utils.sh
 source $(dirname "$0")/build-utils.sh
@@ -15,10 +13,7 @@ BAZEL_TARGETS="
 //native_client:generate_scorer_package
 "

-if [ "${runtime}" = "tflite" ]; then
-  BAZEL_BUILD_TFLITE="--define=runtime=tflite"
-fi;
-BAZEL_BUILD_FLAGS="${BAZEL_BUILD_TFLITE} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
+BAZEL_BUILD_FLAGS="${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
 BAZEL_ENV_FLAGS="TF_NEED_CUDA=0"
 SYSTEM_TARGET=host

View File

@@ -1,30 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_data
node --version
npm --version
export_node_bin_path
check_runtime_nodejs
run_prod_inference_tests "${bitrate}"
run_js_streaming_prod_inference_tests "${bitrate}"

View File

@@ -1,25 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
node --version
npm --version
export_node_bin_path
check_runtime_nodejs
run_all_inference_tests
run_js_streaming_inference_tests
run_hotword_tests

View File

@@ -30,10 +30,15 @@ package_native_client()
     win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so.if.lib"
   fi;

+  if [ -f "${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib" ]; then
+    win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so.if.lib"
+  fi;
+
   ${TAR} --verbose -cf - \
     --transform='flags=r;s|README.coqui|KenLM_License_Info.txt|' \
     -C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so \
     -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so \
+    -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so \
     ${win_lib} \
     -C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
     -C ${stt_dir}/ LICENSE \
@@ -94,5 +99,8 @@ package_libstt_as_zip()
     echo "Please specify artifact name."
   fi;

-  ${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" ${tensorflow_dir}/bazel-bin/native_client/libstt.so
+  ${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" \
+    ${tensorflow_dir}/bazel-bin/native_client/libstt.so \
+    ${tensorflow_dir}/bazel-bin/native_client/libkenlm.so \
+    ${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so
 }
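
Both packaging paths now bundle libtensorflowlite.so (plus its .if.lib import library on Windows) next to libstt.so and libkenlm.so. A quick way to check a produced archive (sketch; the artifact name is hypothetical):

    tar -tf native_client.tar.xz | grep -E 'libstt|libkenlm|libtensorflowlite'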

View File

@@ -1,29 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
model_source=${STT_PROD_MODEL}
model_name=$(basename "${model_source}")
model_source_mmap=${STT_PROD_MODEL_MMAP}
model_name_mmap=$(basename "${model_source_mmap}")
download_model_prod
download_material
export_py_bin_path
which stt
stt --version
run_prod_inference_tests "${bitrate}"
run_prod_concurrent_stream_tests "${bitrate}"

View File

@@ -1,21 +0,0 @@
#!/bin/bash
set -xe
source $(dirname "$0")/all-vars.sh
source $(dirname "$0")/all-utils.sh
source $(dirname "$0")/asserts.sh
bitrate=$1
set_ldc_sample_filename "${bitrate}"
download_data
export_py_bin_path
which stt
stt --version
run_all_inference_tests
run_hotword_tests

View File

@@ -6,7 +6,7 @@ set -o pipefail
 source $(dirname $0)/tf-vars.sh

 pushd ${DS_ROOT_TASK}/tensorflow/
-  BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s --explain bazel_monolithic_tf.log --verbose_explanations --experimental_strict_action_env --config=monolithic"
+  BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s"

   # Start a bazel process to ensure reliability on Windows and avoid:
   # FATAL: corrupt installation: file 'c:\builds\tc-workdir\.bazel_cache/install/6b1660721930e9d5f231f7d2a626209b/_embedded_binaries/build-runfiles.exe' missing.
@@ -23,13 +23,10 @@ pushd ${DS_ROOT_TASK}/tensorflow/
   case "$1" in
   "--windows-cpu")
-    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIBSTT} ${BUILD_TARGET_LITE_LIB} --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh"
+    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
     ;;
   "--linux-cpu"|"--darwin-cpu")
-    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_LITE_LIB}
+    echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
     ;;
-  "--linux-cuda"|"--windows-cuda")
-    eval "export ${TF_CUDA_FLAGS}" && (echo "" | TF_NEED_CUDA=1 ./configure) && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_CUDA_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API}
-    ;;
   "--linux-armv7")
     echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}

View File

@@ -6,26 +6,17 @@ source $(dirname $0)/tf-vars.sh

 mkdir -p ${CI_ARTIFACTS_DIR} || true

-cp ${DS_ROOT_TASK}/tensorflow/bazel_*.log ${CI_ARTIFACTS_DIR} || true
-
 OUTPUT_ROOT="${DS_ROOT_TASK}/tensorflow/bazel-bin"

 for output_bin in \
-  tensorflow/lite/experimental/c/libtensorflowlite_c.so \
-  tensorflow/tools/graph_transforms/transform_graph \
-  tensorflow/tools/graph_transforms/summarize_graph \
-  tensorflow/tools/benchmark/benchmark_model \
-  tensorflow/contrib/util/convert_graphdef_memmapped_format \
-  tensorflow/lite/toco/toco;
+  tensorflow/lite/libtensorflowlite.so \
+  tensorflow/lite/libtensorflowlite.so.if.lib \
+  ;
 do
   if [ -f "${OUTPUT_ROOT}/${output_bin}" ]; then
     cp ${OUTPUT_ROOT}/${output_bin} ${CI_ARTIFACTS_DIR}/
   fi;
-done;
+done

-if [ -f "${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model" ]; then
-  cp ${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model ${CI_ARTIFACTS_DIR}/lite_benchmark_model
-fi

 # It seems that bsdtar and gnutar are behaving a bit differently on the way
 # they deal with --exclude="./public/*" ; this caused ./STT/tensorflow/core/public/
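
With no TF graphs left to process, the TF-tree artifacts shrink to the TFLite shared object and its Windows import library; transform_graph, summarize_graph, benchmark_model, toco and convert_graphdef_memmapped_format are no longer collected, and bazel_*.log files are no longer copied. After a build, the artifact directory should contain roughly this (sketch):

    ls "${CI_ARTIFACTS_DIR}"
    # expected: libtensorflowlite.so (plus libtensorflowlite.so.if.lib on Windows)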

View File

@@ -5,12 +5,7 @@ set -ex
 source $(dirname $0)/tf-vars.sh

 install_android=
-install_cuda=
 case "$1" in
-  "--linux-cuda"|"--windows-cuda")
-    install_cuda=yes
-    ;;
   "--android-armv7"|"--android-arm64")
     install_android=yes
     ;;
@@ -29,11 +24,6 @@ download()
 mkdir -p ${DS_ROOT_TASK}/dls || true

 download $BAZEL_URL $BAZEL_SHA256

-if [ ! -z "${install_cuda}" ]; then
-  download $CUDA_URL $CUDA_SHA256
-  download $CUDNN_URL $CUDNN_SHA256
-fi;
-
 if [ ! -z "${install_android}" ]; then
   download $ANDROID_NDK_URL $ANDROID_NDK_SHA256
   download $ANDROID_SDK_URL $ANDROID_SDK_SHA256
@@ -63,30 +53,6 @@ bazel version

 bazel shutdown

-if [ ! -z "${install_cuda}" ]; then
-  # Install CUDA and CuDNN
-  mkdir -p ${DS_ROOT_TASK}/STT/CUDA/ || true
-  pushd ${DS_ROOT_TASK}
-    CUDA_FILE=`basename ${CUDA_URL}`
-    PERL5LIB=. sh ${DS_ROOT_TASK}/dls/${CUDA_FILE} --silent --override --toolkit --toolkitpath=${DS_ROOT_TASK}/STT/CUDA/ --defaultroot=${DS_ROOT_TASK}/STT/CUDA/
-
-    CUDNN_FILE=`basename ${CUDNN_URL}`
-    tar xvf ${DS_ROOT_TASK}/dls/${CUDNN_FILE} --strip-components=1 -C ${DS_ROOT_TASK}/STT/CUDA/
-  popd
-
-  LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
-  export LD_LIBRARY_PATH
-
-  # We might lack libcuda.so.1 symlink, let's fix as upstream does:
-  # https://github.com/tensorflow/tensorflow/pull/13811/files?diff=split#diff-2352449eb75e66016e97a591d3f0f43dR96
-  if [ ! -h "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1" ]; then
-    ln -s "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so" "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1"
-  fi;
-else
-  echo "No CUDA/CuDNN to install"
-fi

 if [ ! -z "${install_android}" ]; then
   mkdir -p ${DS_ROOT_TASK}/STT/Android/SDK || true
   ANDROID_NDK_FILE=`basename ${ANDROID_NDK_URL}`
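
Setup now installs only Bazel, plus the Android NDK/SDK for Android targets; the multi-gigabyte CUDA/CuDNN download and install disappears along with the libcuda.so.1 stub workaround. The two remaining modes (sketch; the script name here is a hypothetical stand-in, following the tf-* naming used in this directory):

    ./ci_scripts/tf-setup.sh                  # host builds: bazel only
    ./ci_scripts/tf-setup.sh --android-armv7  # additionally fetches the Android NDK/SDK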

View File

@@ -9,13 +9,6 @@ if [ "${OS}" = "Linux" ]; then
   BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh
   BAZEL_SHA256=7ba815cbac712d061fe728fef958651512ff394b2708e89f79586ec93d1185ed

-  CUDA_URL=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
-  CUDA_SHA256=e7c22dc21278eb1b82f34a60ad7640b41ad3943d929bebda3008b72536855d31
-
-  # From https://gitlab.com/nvidia/cuda/blob/centos7/10.1/devel/cudnn7/Dockerfile
-  CUDNN_URL=http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.1-linux-x64-v7.6.0.64.tgz
-  CUDNN_SHA256=e956c6f9222fcb867a10449cfc76dee5cfd7c7531021d95fe9586d7e043b57d7
-
   ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r18b-linux-x86_64.zip
   ANDROID_NDK_SHA256=4f61cbe4bbf6406aa5ef2ae871def78010eed6271af72de83f8bd0b07a9fd3fd
@@ -48,8 +41,6 @@ elif [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
   BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-windows-x86_64.exe
   BAZEL_SHA256=776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9

-  CUDA_INSTALL_DIRECTORY=$(cygpath 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1')
-
   TAR=/usr/bin/tar.exe
 elif [ "${OS}" = "Darwin" ]; then
   if [ -z "${CI_TASK_DIR}" -o -z "${CI_ARTIFACTS_DIR}" ]; then
@@ -89,7 +80,6 @@ fi;
 export PATH

 if [ "${OS}" = "Linux" ]; then
-  export LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
   export ANDROID_SDK_HOME=${DS_ROOT_TASK}/STT/Android/SDK/
   export ANDROID_NDK_HOME=${DS_ROOT_TASK}/STT/Android/android-ndk-r18b/
 fi;
@@ -160,27 +150,15 @@ export BAZEL_OUTPUT_USER_ROOT
 NVCC_COMPUTE="3.5"

-### Define build parameters/env variables that we will re-ues in sourcing scripts.
-if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
-  TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_PATHS=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
-else
-  TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_PATHS=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
-fi
-
 BAZEL_ARM_FLAGS="--config=rpi3 --config=rpi3_opt --copt=-DTFLITE_WITH_RUY_GEMV"
 BAZEL_ARM64_FLAGS="--config=rpi3-armv8 --config=rpi3-armv8_opt --copt=-DTFLITE_WITH_RUY_GEMV"
 BAZEL_ANDROID_ARM_FLAGS="--config=android --config=android_arm --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
 BAZEL_ANDROID_ARM64_FLAGS="--config=android --config=android_arm64 --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
-BAZEL_CUDA_FLAGS="--config=cuda"
-if [ "${OS}" = "Linux" ]; then
-  # constexpr usage in tensorflow's absl dep fails badly because of gcc-5
-  # so let's skip that
-  BAZEL_CUDA_FLAGS="${BAZEL_CUDA_FLAGS} --copt=-DNO_CONSTEXPR_FOR_YOU=1"
-fi
 BAZEL_IOS_ARM64_FLAGS="--config=ios_arm64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
 BAZEL_IOS_X86_64_FLAGS="--config=ios_x86_64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"

 if [ "${OS}" != "${CI_MSYS_VERSION}" ]; then
-  BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl --copt=-fvisibility=hidden"
+  BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl"
 fi

 if [ "${OS}" = "Darwin" ]; then
@@ -189,11 +167,5 @@ fi

 ### Define build targets that we will re-ues in sourcing scripts.
 BUILD_TARGET_LIB_CPP_API="//tensorflow:tensorflow_cc"
-BUILD_TARGET_GRAPH_TRANSFORMS="//tensorflow/tools/graph_transforms:transform_graph"
-BUILD_TARGET_GRAPH_SUMMARIZE="//tensorflow/tools/graph_transforms:summarize_graph"
-BUILD_TARGET_GRAPH_BENCHMARK="//tensorflow/tools/benchmark:benchmark_model"
-#BUILD_TARGET_CONVERT_MMAP="//tensorflow/contrib/util:convert_graphdef_memmapped_format"
-BUILD_TARGET_TOCO="//tensorflow/lite/toco:toco"
-BUILD_TARGET_LITE_BENCHMARK="//tensorflow/lite/tools/benchmark:benchmark_model"
-BUILD_TARGET_LITE_LIB="//tensorflow/lite/c:libtensorflowlite_c.so"
+BUILD_TARGET_LITE_LIB="//tensorflow/lite:libtensorflowlite.so"
 BUILD_TARGET_LIBSTT="//native_client:libstt.so"
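
One leftover worth noting: BAZEL_IOS_ARM64_FLAGS and BAZEL_IOS_X86_64_FLAGS above still pass --define=runtime=tflite, while the `tflite` config_setting that consumed that define in native_client/BUILD is deleted in this same commit (below), so at least there the define is now inert. A quick scan for remaining uses (sketch, from the repo root):

    grep -rn -- '--define=runtime=tflite' ci_scripts/ native_client/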

View File

@@ -1,22 +1,9 @@
 # Description: Coqui STT native client library.

-load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object", "tf_copts", "lrt_if_needed")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "lrt_if_needed")
 load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
 load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
-load(
-    "@org_tensorflow//tensorflow/lite:build_def.bzl",
-    "tflite_copts",
-    "tflite_linkopts",
-)
-
-config_setting(
-    name = "tflite",
-    define_values = {
-        "runtime": "tflite",
-    },
-)

 config_setting(
     name = "rpi3",
@@ -85,7 +72,7 @@ LINUX_LINKOPTS = [
     "-Wl,-export-dynamic",
 ]

-tf_cc_shared_object(
+cc_binary(
     name = "libkenlm.so",
     srcs = glob([
         "kenlm/lm/*.hh",
@@ -107,8 +94,19 @@ tf_cc_shared_object(
     }),
     defines = ["KENLM_MAX_ORDER=6"],
     includes = ["kenlm"],
-    framework_so = [],
-    linkopts = [],
+    linkshared = 1,
+    linkopts = select({
+        "//tensorflow:ios": [
+            "-Wl,-install_name,@rpath/libkenlm.so",
+        ],
+        "//tensorflow:macos": [
+            "-Wl,-install_name,@rpath/libkenlm.so",
+        ],
+        "//tensorflow:windows": [],
+        "//conditions:default": [
+            "-Wl,-soname,libkenlm.so",
+        ],
+    }),
 )

 cc_library(
@@ -132,6 +130,20 @@ cc_library(
     copts = ["-fexceptions"],
 )

+cc_library(
+    name="tflite",
+    hdrs = [
+        "//tensorflow/lite:model.h",
+        "//tensorflow/lite/kernels:register.h",
+        "//tensorflow/lite/tools/evaluation:utils.h",
+    ],
+    srcs = [
+        "//tensorflow/lite:libtensorflowlite.so",
+    ],
+    includes = ["tensorflow"],
+    deps = ["//tensorflow/lite:libtensorflowlite.so"],
+)
+
 cc_library(
     name = "coqui_stt_bundle",
     srcs = [
@@ -142,17 +154,10 @@ cc_library(
         "modelstate.h",
         "workspace_status.cc",
         "workspace_status.h",
-    ] + select({
-        "//native_client:tflite": [
-            "tflitemodelstate.h",
-            "tflitemodelstate.cc",
-        ],
-        "//conditions:default": [
-            "tfmodelstate.h",
-            "tfmodelstate.cc",
-        ],
-    }) + DECODER_SOURCES,
-    copts = tf_copts(allow_exceptions=True) + select({
+        "tflitemodelstate.h",
+        "tflitemodelstate.cc",
+    ] + DECODER_SOURCES,
+    copts = select({
         # -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
         "//tensorflow:windows": ["/w"],
         # -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
@@ -161,9 +166,6 @@ cc_library(
             "-Wno-sign-compare",
             "-fvisibility=hidden",
         ],
-    }) + select({
-        "//native_client:tflite": ["-DUSE_TFLITE"],
-        "//conditions:default": ["-UUSE_TFLITE"],
     }),
     linkopts = lrt_if_needed() + select({
         "//tensorflow:macos": [],
@@ -174,64 +176,32 @@ cc_library(
         # Bazel is has too strong opinions about static linking, so it's
        # near impossible to get it to link a DLL against another DLL on Windows.
        # We simply force the linker option manually here as a hacky fix.
-        "//tensorflow:windows": ["bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib"],
+        "//tensorflow:windows": [
+            "bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib",
+            "bazel-out/x64_windows-opt/bin/tensorflow/lite/libtensorflowlite.so.if.lib",
+        ],
         "//conditions:default": [],
-    }) + tflite_linkopts() + DECODER_LINKOPTS,
+    }) + DECODER_LINKOPTS,
     includes = DECODER_INCLUDES,
-    deps = select({
-        "//native_client:tflite": [
-            "//tensorflow/lite/kernels:builtin_ops",
-            "//tensorflow/lite/tools/evaluation:utils",
-        ],
-        "//conditions:default": [
-            "//tensorflow/core:core_cpu",
-            "//tensorflow/core:direct_session",
-            "//third_party/eigen3",
-            #"//tensorflow/core:all_kernels",
-            ### => Trying to be more fine-grained
-            ### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
-            ### CPU only build, libstt.so file size reduced by ~50%
-            "//tensorflow/core/kernels:spectrogram_op",  # AudioSpectrogram
-            "//tensorflow/core/kernels:bias_op",  # BiasAdd
-            "//tensorflow/core/kernels:cast_op",  # Cast
-            "//tensorflow/core/kernels:concat_op",  # ConcatV2
-            "//tensorflow/core/kernels:constant_op",  # Const, Placeholder
-            "//tensorflow/core/kernels:shape_ops",  # ExpandDims, Shape
-            "//tensorflow/core/kernels:gather_nd_op",  # GatherNd
-            "//tensorflow/core/kernels:identity_op",  # Identity
-            "//tensorflow/core/kernels:immutable_constant_op",  # ImmutableConst (used in memmapped models)
-            "//tensorflow/core/kernels:deepspeech_cwise_ops",  # Less, Minimum, Mul
-            "//tensorflow/core/kernels:matmul_op",  # MatMul
-            "//tensorflow/core/kernels:reduction_ops",  # Max
-            "//tensorflow/core/kernels:mfcc_op",  # Mfcc
-            "//tensorflow/core/kernels:no_op",  # NoOp
-            "//tensorflow/core/kernels:pack_op",  # Pack
-            "//tensorflow/core/kernels:sequence_ops",  # Range
-            "//tensorflow/core/kernels:relu_op",  # Relu
-            "//tensorflow/core/kernels:reshape_op",  # Reshape
-            "//tensorflow/core/kernels:softmax_op",  # Softmax
-            "//tensorflow/core/kernels:tile_ops",  # Tile
-            "//tensorflow/core/kernels:transpose_op",  # Transpose
-            "//tensorflow/core/kernels:rnn_ops",  # BlockLSTM
-            # And we also need the op libs for these ops used in the model:
-            "//tensorflow/core:audio_ops_op_lib",  # AudioSpectrogram, Mfcc
-            "//tensorflow/core:rnn_ops_op_lib",  # BlockLSTM
-            "//tensorflow/core:math_ops_op_lib",  # Cast, Less, Max, MatMul, Minimum, Range
-            "//tensorflow/core:array_ops_op_lib",  # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
-            "//tensorflow/core:no_op_op_lib",  # NoOp
-            "//tensorflow/core:nn_ops_op_lib",  # Relu, Softmax, BiasAdd
-            # And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
-            "//tensorflow/core:dataset_ops_op_lib",  # UnwrapDatasetVariant, WrapDatasetVariant
-            "//tensorflow/core:sendrecv_ops_op_lib",  # _HostRecv, _HostSend, _Recv, _Send
-        ],
-    }) + if_cuda([
-        "//tensorflow/core:core",
-    ]) + [":kenlm"],
+    deps = [":kenlm", ":tflite"],
 )

-tf_cc_shared_object(
+cc_binary(
     name = "libstt.so",
     deps = [":coqui_stt_bundle"],
+    linkshared = 1,
+    linkopts = select({
+        "//tensorflow:ios": [
+            "-Wl,-install_name,@rpath/libstt.so",
+        ],
+        "//tensorflow:macos": [
+            "-Wl,-install_name,@rpath/libstt.so",
+        ],
+        "//tensorflow:windows": [],
+        "//conditions:default": [
+            "-Wl,-soname,libstt.so",
+        ],
+    }),
 )

 ios_static_framework(
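
Dropping tf_cc_shared_object means libkenlm.so and libstt.so are plain cc_binary shared objects with explicit soname/install_name linkopts, and the STT bundle now reaches TFLite only through the prebuilt //tensorflow/lite:libtensorflowlite.so via the small `tflite` wrapper library, rather than a select() over two backends. Building the pieces directly looks roughly like this (sketch; config flags per tf-vars.sh above):

    bazel build --config=noaws --config=nogcp --config=nohdfs --config=nonccl \
      -c opt //tensorflow/lite:libtensorflowlite.so //native_client:libstt.so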
ios_static_framework( ios_static_framework(

View File

@@ -20,8 +20,8 @@ endif
 STT_BIN := stt$(PLATFORM_EXE_SUFFIX)
 CFLAGS_STT := -std=c++11 -o $(STT_BIN)
-LINK_STT := -lstt -lkenlm
-LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client
+LINK_STT := -lstt -lkenlm -ltensorflowlite
+LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client -L${TFDIR}/bazel-bin/tensorflow/lite

 ifeq ($(TARGET),host)
 TOOLCHAIN :=
@@ -61,7 +61,7 @@ TOOL_CC := cl.exe
 TOOL_CXX := cl.exe
 TOOL_LD := link.exe
 TOOL_LIBEXE := lib.exe
-LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib")
+LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib")
 LINK_PATH_STT :=
 CFLAGS_STT := -nologo -Fe$(STT_BIN)
 SOX_CFLAGS :=
@@ -185,7 +185,7 @@ define copy_missing_libs
     new_missing="$$( (for f in $$(otool -L $$lib 2>/dev/null | tail -n +2 | awk '{ print $$1 }' | grep -v '$$lib'); do ls -hal $$f; done;) 2>&1 | grep 'No such' | cut -d':' -f2 | xargs basename -a)"; \
     missing_libs="$$missing_libs $$new_missing"; \
   elif [ "$(OS)" = "${CI_MSYS_VERSION}" ]; then \
-    missing_libs="libstt.so libkenlm.so"; \
+    missing_libs="libstt.so libkenlm.so libtensorflowlite.so"; \
   else \
     missing_libs="$$missing_libs $$($(LDD) $$lib | grep 'not found' | awk '{ print $$1 }')"; \
   fi; \
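
The stt client now links -ltensorflowlite directly, so libtensorflowlite.so must be resolvable at run time, exactly like libstt.so and libkenlm.so. A quick sanity check on Linux (sketch; paths assume the in-tree bazel-bin layout used throughout this commit):

    LD_LIBRARY_PATH=tensorflow/bazel-bin/tensorflow/lite:tensorflow/bazel-bin/native_client \
      ldd native_client/stt | grep -E 'libstt|libkenlm|libtensorflowlite'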

View File

@@ -27,12 +27,14 @@
             "libraries": [
                 "../../../tensorflow/bazel-bin/native_client/libstt.so.if.lib",
                 "../../../tensorflow/bazel-bin/native_client/libkenlm.so.if.lib",
+                "../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib",
             ],
         },
         {
             "libraries": [
                 "../../../tensorflow/bazel-bin/native_client/libstt.so",
                 "../../../tensorflow/bazel-bin/native_client/libkenlm.so",
+                "../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so",
             ],
         },
     ],

View File

@@ -14,13 +14,7 @@
 #include "modelstate.h"
 #include "workspace_status.h"

-#ifndef USE_TFLITE
-#include "tfmodelstate.h"
-#else
 #include "tflitemodelstate.h"
-#endif // USE_TFLITE

 #include "ctcdecode/ctc_beam_search_decoder.h"

 #ifdef __ANDROID__
@@ -282,13 +276,7 @@ STT_CreateModel(const char* aModelPath,
     return STT_ERR_NO_MODEL;
   }

-  std::unique_ptr<ModelState> model(
-#ifndef USE_TFLITE
-    new TFModelState()
-#else
-    new TFLiteModelState()
-#endif
-  );
+  std::unique_ptr<ModelState> model(new TFLiteModelState());

   if (!model) {
     std::cerr << "Could not allocate model state." << std::endl;

View File

@@ -1,263 +0,0 @@
#include "tfmodelstate.h"
#include "workspace_status.h"
using namespace tensorflow;
using std::vector;
TFModelState::TFModelState()
: ModelState()
, mmap_env_(nullptr)
, session_(nullptr)
{
}
TFModelState::~TFModelState()
{
if (session_) {
Status status = session_->Close();
if (!status.ok()) {
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
}
}
}
int
TFModelState::init(const char* model_path)
{
int err = ModelState::init(model_path);
if (err != STT_ERR_OK) {
return err;
}
Status status;
SessionOptions options;
mmap_env_.reset(new MemmappedEnv(Env::Default()));
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
if (!is_mmap) {
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
} else {
status = mmap_env_->InitializeFromFile(model_path);
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_INIT_MMAP;
}
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(::OptimizerOptions::L0);
options.env = mmap_env_.get();
}
Session* session;
status = NewSession(options, &session);
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_INIT_SESS;
}
session_.reset(session);
if (is_mmap) {
status = ReadBinaryProto(mmap_env_.get(),
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&graph_def_);
} else {
status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
}
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_READ_PROTOBUF;
}
status = session_->Create(graph_def_);
if (!status.ok()) {
std::cerr << status << std::endl;
return STT_ERR_FAIL_CREATE_SESS;
}
std::vector<tensorflow::Tensor> version_output;
status = session_->Run({}, {
"metadata_version"
}, {}, &version_output);
if (!status.ok()) {
std::cerr << "Unable to fetch graph version: " << status << std::endl;
return STT_ERR_MODEL_INCOMPATIBLE;
}
int graph_version = version_output[0].scalar<int>()();
if (graph_version < ds_graph_version()) {
std::cerr << "Specified model file version (" << graph_version << ") is "
<< "incompatible with minimum version supported by this client ("
<< ds_graph_version() << "). See "
<< "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
<< "for more information" << std::endl;
return STT_ERR_MODEL_INCOMPATIBLE;
}
std::vector<tensorflow::Tensor> metadata_outputs;
status = session_->Run({}, {
"metadata_sample_rate",
"metadata_feature_win_len",
"metadata_feature_win_step",
"metadata_beam_width",
"metadata_alphabet",
}, {}, &metadata_outputs);
if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl;
return STT_ERR_MODEL_INCOMPATIBLE;
}
sample_rate_ = metadata_outputs[0].scalar<int>()();
int win_len_ms = metadata_outputs[1].scalar<int>()();
int win_step_ms = metadata_outputs[2].scalar<int>()();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
int beam_width = metadata_outputs[3].scalar<int>()();
beam_width_ = (unsigned int)(beam_width);
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
if (err != 0) {
return STT_ERR_INVALID_ALPHABET;
}
assert(sample_rate_ > 0);
assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0);
assert(beam_width_ > 0);
assert(alphabet_.GetSize() > 0);
for (int i = 0; i < graph_def_.node_size(); ++i) {
NodeDef node = graph_def_.node(i);
if (node.name() == "input_node") {
const auto& shape = node.attr().at("shape").shape();
n_steps_ = shape.dim(1).size();
n_context_ = (shape.dim(2).size()-1)/2;
n_features_ = shape.dim(3).size();
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
} else if (node.name() == "previous_state_c") {
const auto& shape = node.attr().at("shape").shape();
state_size_ = shape.dim(1).size();
} else if (node.name() == "logits_shape") {
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
continue;
}
int final_dim_size = logits_shape.vec<int>()(2) - 1;
if (final_dim_size != alphabet_.GetSize()) {
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
<< "has size " << alphabet_.GetSize()
<< ", but model has " << final_dim_size
<< " classes in its output. Make sure you're passing an alphabet "
<< "file with the same size as the one used for training."
<< std::endl;
return STT_ERR_INVALID_ALPHABET;
}
}
}
if (n_context_ == -1 || n_features_ == -1) {
std::cerr << "Error: Could not infer input shape from model file. "
<< "Make sure input_node is a 4D tensor with shape "
<< "[batch_size=1, time, window_size, n_features]."
<< std::endl;
return STT_ERR_INVALID_SHAPE;
}
return STT_ERR_OK;
}
Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
Tensor ret(DT_FLOAT, shape);
auto ret_mapped = ret.flat<float>();
int i;
for (i = 0; i < vec.size(); ++i) {
ret_mapped(i) = vec[i];
}
for (; i < shape.num_elements(); ++i) {
ret_mapped(i) = 0.f;
}
return ret;
}
void
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
auto tensor_mapped = tensor.flat<float>();
if (num_elements == -1) {
num_elements = tensor.shape().num_elements();
}
for (int i = 0; i < num_elements; ++i) {
vec.push_back(tensor_mapped(i));
}
}
void
TFModelState::infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
vector<float>& logits_output,
vector<float>& state_c_output,
vector<float>& state_h_output)
{
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor input_lengths(DT_INT32, TensorShape({1}));
input_lengths.scalar<int>()() = n_frames;
vector<Tensor> outputs;
Status status = session_->Run(
{
{"input_node", input},
{"input_lengths", input_lengths},
{"previous_state_c", previous_state_c_t},
{"previous_state_h", previous_state_h_t}
},
{"logits", "new_state_c", "new_state_h"},
{},
&outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
state_c_output.clear();
state_c_output.reserve(state_size_);
copy_tensor_to_vector(outputs[1], state_c_output);
state_h_output.clear();
state_h_output.reserve(state_size_);
copy_tensor_to_vector(outputs[2], state_h_output);
}
void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
vector<Tensor> outputs;
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
// The feature computation graph is hardcoded to one audio length for now
const int n_windows = 1;
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
copy_tensor_to_vector(outputs[0], mfcc_output);
}

View File

@@ -1,35 +0,0 @@
#ifndef TFMODELSTATE_H
#define TFMODELSTATE_H
#include <vector>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"
#include "modelstate.h"
struct TFModelState : public ModelState
{
std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
std::unique_ptr<tensorflow::Session> session_;
tensorflow::GraphDef graph_def_;
TFModelState();
virtual ~TFModelState();
virtual int init(const char* model_path) override;
virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
std::vector<float>& logits_output,
std::vector<float>& state_c_output,
std::vector<float>& state_h_output) override;
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
std::vector<float>& mfcc_output) override;
};
#endif // TFMODELSTATE_H