Remove full TF backend
This commit is contained in:
parent
2020f1b15a
commit
42ebbf9120
|
@ -5,11 +5,8 @@ inputs:
|
|||
description: "Target arch for loading script (host/armv7/aarch64)"
|
||||
required: false
|
||||
default: "host"
|
||||
flavor:
|
||||
description: "Build flavor"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- run: ./ci_scripts/${{ inputs.arch }}-build.sh ${{ inputs.flavor }}
|
||||
- run: ./ci_scripts/${{ inputs.arch }}-build.sh
|
||||
shell: bash
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
name: "Python binding"
|
||||
description: "Binding a python binding"
|
||||
inputs:
|
||||
build_flavor:
|
||||
description: "Python package name"
|
||||
required: true
|
||||
numpy_build:
|
||||
description: "NumPy build dependecy"
|
||||
required: true
|
||||
|
@ -46,9 +43,6 @@ runs:
|
|||
set -xe
|
||||
|
||||
PROJECT_NAME="stt"
|
||||
if [ "${{ inputs.build_flavor }}" = "tflite" ]; then
|
||||
PROJECT_NAME="stt-tflite"
|
||||
fi
|
||||
|
||||
OS=$(uname)
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
|
|
|
@ -4,9 +4,6 @@ inputs:
|
|||
runtime:
|
||||
description: "Runtime to use for running test"
|
||||
required: true
|
||||
build-flavor:
|
||||
description: "Running against TF or TFLite"
|
||||
required: true
|
||||
model-kind:
|
||||
description: "Running against CI baked or production model"
|
||||
required: true
|
||||
|
@ -22,10 +19,7 @@ runs:
|
|||
- run: |
|
||||
set -xe
|
||||
|
||||
build=""
|
||||
if [ "${{ inputs.build-flavor }}" = "tflite" ]; then
|
||||
build="_tflite"
|
||||
fi
|
||||
build="_tflite"
|
||||
|
||||
model_kind=""
|
||||
if [ "${{ inputs.model-kind }}" = "prod" ]; then
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -55,23 +55,6 @@ maybe_install_xldd()
|
|||
fi
|
||||
}
|
||||
|
||||
# Checks whether we run a patched version of bazel.
|
||||
# Patching is required to dump computeKey() parameters to .ckd files
|
||||
# See bazel.patch
|
||||
# Return 0 (success exit code) on patched version, 1 on release version
|
||||
is_patched_bazel()
|
||||
{
|
||||
bazel_version=$(bazel version | grep 'Build label:' | cut -d':' -f2)
|
||||
|
||||
bazel shutdown
|
||||
|
||||
if [ -z "${bazel_version}" ]; then
|
||||
return 0;
|
||||
else
|
||||
return 1;
|
||||
fi;
|
||||
}
|
||||
|
||||
verify_bazel_rebuild()
|
||||
{
|
||||
bazel_explain_file="$1"
|
||||
|
|
|
@ -9,21 +9,15 @@ do_bazel_build()
|
|||
cd ${DS_TFDIR}
|
||||
eval "export ${BAZEL_ENV_FLAGS}"
|
||||
|
||||
if [ "${_opt_or_dbg}" = "opt" ]; then
|
||||
if is_patched_bazel; then
|
||||
find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-tf.tar -T -
|
||||
fi;
|
||||
fi;
|
||||
|
||||
bazel ${BAZEL_OUTPUT_USER_ROOT} build \
|
||||
-s --explain bazel_monolithic.log --verbose_explanations --experimental_strict_action_env --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
|
||||
-s --explain bazel_explain.log --verbose_explanations \
|
||||
--experimental_strict_action_env \
|
||||
--workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
|
||||
-c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
|
||||
|
||||
if [ "${_opt_or_dbg}" = "opt" ]; then
|
||||
if is_patched_bazel; then
|
||||
find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-ds.tar -T -
|
||||
fi;
|
||||
verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_monolithic.log"
|
||||
fi;
|
||||
verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_explain.log"
|
||||
fi
|
||||
}
|
||||
|
||||
shutdown_bazel()
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
model_source=${STT_PROD_MODEL}
|
||||
model_name=$(basename "${model_source}")
|
||||
|
||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
||||
model_name_mmap=$(basename "${model_source_mmap}")
|
||||
|
||||
download_model_prod
|
||||
|
||||
download_material
|
||||
|
||||
export PATH=${CI_TMP_DIR}/ds/:$PATH
|
||||
|
||||
check_versions
|
||||
|
||||
run_prod_inference_tests "${bitrate}"
|
|
@ -1,24 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
download_data
|
||||
|
||||
export PATH=${CI_TMP_DIR}/ds/:$PATH
|
||||
|
||||
check_versions
|
||||
|
||||
run_all_inference_tests
|
||||
|
||||
run_multi_inference_tests
|
||||
|
||||
run_cpp_only_inference_tests
|
||||
|
||||
run_hotword_tests
|
|
@ -1,20 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
download_material "${CI_TMP_DIR}/ds"
|
||||
|
||||
export PATH=${CI_TMP_DIR}/ds/:$PATH
|
||||
|
||||
check_versions
|
||||
|
||||
ensure_cuda_usage "$2"
|
||||
|
||||
run_basic_inference_tests
|
|
@ -1,48 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
model_source=${STT_PROD_MODEL}
|
||||
model_name=$(basename "${model_source}")
|
||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
||||
model_name_mmap=$(basename "${model_source_mmap}")
|
||||
|
||||
download_model_prod
|
||||
|
||||
download_data
|
||||
|
||||
node --version
|
||||
npm --version
|
||||
|
||||
symlink_electron
|
||||
|
||||
export_node_bin_path
|
||||
|
||||
which electron
|
||||
which node
|
||||
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
export DISPLAY=':99.0'
|
||||
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
|
||||
xvfb_process=$!
|
||||
fi
|
||||
|
||||
node --version
|
||||
|
||||
stt --version
|
||||
|
||||
check_runtime_electronjs
|
||||
|
||||
run_electronjs_prod_inference_tests "${bitrate}"
|
||||
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
sleep 1
|
||||
sudo kill -9 ${xvfb_process} || true
|
||||
fi
|
|
@ -1,41 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
download_data
|
||||
|
||||
node --version
|
||||
npm --version
|
||||
|
||||
symlink_electron
|
||||
|
||||
export_node_bin_path
|
||||
|
||||
which electron
|
||||
which node
|
||||
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
export DISPLAY=':99.0'
|
||||
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
|
||||
xvfb_process=$!
|
||||
fi
|
||||
|
||||
node --version
|
||||
|
||||
stt --version
|
||||
|
||||
check_runtime_electronjs
|
||||
|
||||
run_electronjs_inference_tests
|
||||
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
sleep 1
|
||||
sudo kill -9 ${xvfb_process} || true
|
||||
fi
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
set -xe
|
||||
|
||||
runtime=$1
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/build-utils.sh
|
||||
|
@ -15,10 +13,7 @@ BAZEL_TARGETS="
|
|||
//native_client:generate_scorer_package
|
||||
"
|
||||
|
||||
if [ "${runtime}" = "tflite" ]; then
|
||||
BAZEL_BUILD_TFLITE="--define=runtime=tflite"
|
||||
fi;
|
||||
BAZEL_BUILD_FLAGS="${BAZEL_BUILD_TFLITE} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
|
||||
BAZEL_BUILD_FLAGS="${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
|
||||
|
||||
BAZEL_ENV_FLAGS="TF_NEED_CUDA=0"
|
||||
SYSTEM_TARGET=host
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
model_source=${STT_PROD_MODEL}
|
||||
model_name=$(basename "${model_source}")
|
||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
||||
model_name_mmap=$(basename "${model_source_mmap}")
|
||||
|
||||
download_model_prod
|
||||
|
||||
download_data
|
||||
|
||||
node --version
|
||||
npm --version
|
||||
|
||||
export_node_bin_path
|
||||
|
||||
check_runtime_nodejs
|
||||
|
||||
run_prod_inference_tests "${bitrate}"
|
||||
|
||||
run_js_streaming_prod_inference_tests "${bitrate}"
|
|
@ -1,25 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
download_data
|
||||
|
||||
node --version
|
||||
npm --version
|
||||
|
||||
export_node_bin_path
|
||||
|
||||
check_runtime_nodejs
|
||||
|
||||
run_all_inference_tests
|
||||
|
||||
run_js_streaming_inference_tests
|
||||
|
||||
run_hotword_tests
|
|
@ -30,10 +30,15 @@ package_native_client()
|
|||
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so.if.lib"
|
||||
fi;
|
||||
|
||||
if [ -f "${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib" ]; then
|
||||
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so.if.lib"
|
||||
fi;
|
||||
|
||||
${TAR} --verbose -cf - \
|
||||
--transform='flags=r;s|README.coqui|KenLM_License_Info.txt|' \
|
||||
-C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so \
|
||||
-C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so \
|
||||
-C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so \
|
||||
${win_lib} \
|
||||
-C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
|
||||
-C ${stt_dir}/ LICENSE \
|
||||
|
@ -94,5 +99,8 @@ package_libstt_as_zip()
|
|||
echo "Please specify artifact name."
|
||||
fi;
|
||||
|
||||
${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" ${tensorflow_dir}/bazel-bin/native_client/libstt.so
|
||||
${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" \
|
||||
${tensorflow_dir}/bazel-bin/native_client/libstt.so \
|
||||
${tensorflow_dir}/bazel-bin/native_client/libkenlm.so \
|
||||
${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so
|
||||
}
|
||||
|
|
|
@ -1,29 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
model_source=${STT_PROD_MODEL}
|
||||
model_name=$(basename "${model_source}")
|
||||
|
||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
||||
model_name_mmap=$(basename "${model_source_mmap}")
|
||||
|
||||
download_model_prod
|
||||
|
||||
download_material
|
||||
|
||||
export_py_bin_path
|
||||
|
||||
which stt
|
||||
stt --version
|
||||
|
||||
run_prod_inference_tests "${bitrate}"
|
||||
|
||||
run_prod_concurrent_stream_tests "${bitrate}"
|
|
@ -1,21 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -xe
|
||||
|
||||
source $(dirname "$0")/all-vars.sh
|
||||
source $(dirname "$0")/all-utils.sh
|
||||
source $(dirname "$0")/asserts.sh
|
||||
|
||||
bitrate=$1
|
||||
set_ldc_sample_filename "${bitrate}"
|
||||
|
||||
download_data
|
||||
|
||||
export_py_bin_path
|
||||
|
||||
which stt
|
||||
stt --version
|
||||
|
||||
run_all_inference_tests
|
||||
|
||||
run_hotword_tests
|
|
@ -6,7 +6,7 @@ set -o pipefail
|
|||
source $(dirname $0)/tf-vars.sh
|
||||
|
||||
pushd ${DS_ROOT_TASK}/tensorflow/
|
||||
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s --explain bazel_monolithic_tf.log --verbose_explanations --experimental_strict_action_env --config=monolithic"
|
||||
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s"
|
||||
|
||||
# Start a bazel process to ensure reliability on Windows and avoid:
|
||||
# FATAL: corrupt installation: file 'c:\builds\tc-workdir\.bazel_cache/install/6b1660721930e9d5f231f7d2a626209b/_embedded_binaries/build-runfiles.exe' missing.
|
||||
|
@ -18,18 +18,15 @@ pushd ${DS_ROOT_TASK}/tensorflow/
|
|||
MAYBE_DEBUG=$2
|
||||
OPT_OR_DBG="-c opt"
|
||||
if [ "${MAYBE_DEBUG}" = "dbg" ]; then
|
||||
OPT_OR_DBG="-c dbg"
|
||||
OPT_OR_DBG="-c dbg"
|
||||
fi;
|
||||
|
||||
case "$1" in
|
||||
"--windows-cpu")
|
||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIBSTT} ${BUILD_TARGET_LITE_LIB} --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh"
|
||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
||||
;;
|
||||
"--linux-cpu"|"--darwin-cpu")
|
||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_LITE_LIB}
|
||||
;;
|
||||
"--linux-cuda"|"--windows-cuda")
|
||||
eval "export ${TF_CUDA_FLAGS}" && (echo "" | TF_NEED_CUDA=1 ./configure) && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_CUDA_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API}
|
||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
||||
;;
|
||||
"--linux-armv7")
|
||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
||||
|
|
|
@ -6,26 +6,17 @@ source $(dirname $0)/tf-vars.sh
|
|||
|
||||
mkdir -p ${CI_ARTIFACTS_DIR} || true
|
||||
|
||||
cp ${DS_ROOT_TASK}/tensorflow/bazel_*.log ${CI_ARTIFACTS_DIR} || true
|
||||
|
||||
OUTPUT_ROOT="${DS_ROOT_TASK}/tensorflow/bazel-bin"
|
||||
|
||||
for output_bin in \
|
||||
tensorflow/lite/experimental/c/libtensorflowlite_c.so \
|
||||
tensorflow/tools/graph_transforms/transform_graph \
|
||||
tensorflow/tools/graph_transforms/summarize_graph \
|
||||
tensorflow/tools/benchmark/benchmark_model \
|
||||
tensorflow/contrib/util/convert_graphdef_memmapped_format \
|
||||
tensorflow/lite/toco/toco;
|
||||
for output_bin in \
|
||||
tensorflow/lite/libtensorflow.so \
|
||||
tensorflow/lite/libtensorflow.so.if.lib \
|
||||
;
|
||||
do
|
||||
if [ -f "${OUTPUT_ROOT}/${output_bin}" ]; then
|
||||
cp ${OUTPUT_ROOT}/${output_bin} ${CI_ARTIFACTS_DIR}/
|
||||
fi;
|
||||
done;
|
||||
|
||||
if [ -f "${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model" ]; then
|
||||
cp ${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model ${CI_ARTIFACTS_DIR}/lite_benchmark_model
|
||||
fi
|
||||
done
|
||||
|
||||
# It seems that bsdtar and gnutar are behaving a bit differently on the way
|
||||
# they deal with --exclude="./public/*" ; this caused ./STT/tensorflow/core/public/
|
||||
|
|
|
@ -5,12 +5,7 @@ set -ex
|
|||
source $(dirname $0)/tf-vars.sh
|
||||
|
||||
install_android=
|
||||
install_cuda=
|
||||
case "$1" in
|
||||
"--linux-cuda"|"--windows-cuda")
|
||||
install_cuda=yes
|
||||
;;
|
||||
|
||||
"--android-armv7"|"--android-arm64")
|
||||
install_android=yes
|
||||
;;
|
||||
|
@ -29,11 +24,6 @@ download()
|
|||
mkdir -p ${DS_ROOT_TASK}/dls || true
|
||||
download $BAZEL_URL $BAZEL_SHA256
|
||||
|
||||
if [ ! -z "${install_cuda}" ]; then
|
||||
download $CUDA_URL $CUDA_SHA256
|
||||
download $CUDNN_URL $CUDNN_SHA256
|
||||
fi;
|
||||
|
||||
if [ ! -z "${install_android}" ]; then
|
||||
download $ANDROID_NDK_URL $ANDROID_NDK_SHA256
|
||||
download $ANDROID_SDK_URL $ANDROID_SDK_SHA256
|
||||
|
@ -63,30 +53,6 @@ bazel version
|
|||
|
||||
bazel shutdown
|
||||
|
||||
if [ ! -z "${install_cuda}" ]; then
|
||||
# Install CUDA and CuDNN
|
||||
mkdir -p ${DS_ROOT_TASK}/STT/CUDA/ || true
|
||||
pushd ${DS_ROOT_TASK}
|
||||
CUDA_FILE=`basename ${CUDA_URL}`
|
||||
PERL5LIB=. sh ${DS_ROOT_TASK}/dls/${CUDA_FILE} --silent --override --toolkit --toolkitpath=${DS_ROOT_TASK}/STT/CUDA/ --defaultroot=${DS_ROOT_TASK}/STT/CUDA/
|
||||
|
||||
CUDNN_FILE=`basename ${CUDNN_URL}`
|
||||
tar xvf ${DS_ROOT_TASK}/dls/${CUDNN_FILE} --strip-components=1 -C ${DS_ROOT_TASK}/STT/CUDA/
|
||||
popd
|
||||
|
||||
LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
|
||||
export LD_LIBRARY_PATH
|
||||
|
||||
# We might lack libcuda.so.1 symlink, let's fix as upstream does:
|
||||
# https://github.com/tensorflow/tensorflow/pull/13811/files?diff=split#diff-2352449eb75e66016e97a591d3f0f43dR96
|
||||
if [ ! -h "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1" ]; then
|
||||
ln -s "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so" "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1"
|
||||
fi;
|
||||
|
||||
else
|
||||
echo "No CUDA/CuDNN to install"
|
||||
fi
|
||||
|
||||
if [ ! -z "${install_android}" ]; then
|
||||
mkdir -p ${DS_ROOT_TASK}/STT/Android/SDK || true
|
||||
ANDROID_NDK_FILE=`basename ${ANDROID_NDK_URL}`
|
||||
|
|
|
@ -9,13 +9,6 @@ if [ "${OS}" = "Linux" ]; then
|
|||
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh
|
||||
BAZEL_SHA256=7ba815cbac712d061fe728fef958651512ff394b2708e89f79586ec93d1185ed
|
||||
|
||||
CUDA_URL=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
|
||||
CUDA_SHA256=e7c22dc21278eb1b82f34a60ad7640b41ad3943d929bebda3008b72536855d31
|
||||
|
||||
# From https://gitlab.com/nvidia/cuda/blob/centos7/10.1/devel/cudnn7/Dockerfile
|
||||
CUDNN_URL=http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.1-linux-x64-v7.6.0.64.tgz
|
||||
CUDNN_SHA256=e956c6f9222fcb867a10449cfc76dee5cfd7c7531021d95fe9586d7e043b57d7
|
||||
|
||||
ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r18b-linux-x86_64.zip
|
||||
ANDROID_NDK_SHA256=4f61cbe4bbf6406aa5ef2ae871def78010eed6271af72de83f8bd0b07a9fd3fd
|
||||
|
||||
|
@ -48,8 +41,6 @@ elif [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
|
|||
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-windows-x86_64.exe
|
||||
BAZEL_SHA256=776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9
|
||||
|
||||
CUDA_INSTALL_DIRECTORY=$(cygpath 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1')
|
||||
|
||||
TAR=/usr/bin/tar.exe
|
||||
elif [ "${OS}" = "Darwin" ]; then
|
||||
if [ -z "${CI_TASK_DIR}" -o -z "${CI_ARTIFACTS_DIR}" ]; then
|
||||
|
@ -89,7 +80,6 @@ fi;
|
|||
export PATH
|
||||
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
export LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
|
||||
export ANDROID_SDK_HOME=${DS_ROOT_TASK}/STT/Android/SDK/
|
||||
export ANDROID_NDK_HOME=${DS_ROOT_TASK}/STT/Android/android-ndk-r18b/
|
||||
fi;
|
||||
|
@ -160,27 +150,15 @@ export BAZEL_OUTPUT_USER_ROOT
|
|||
|
||||
NVCC_COMPUTE="3.5"
|
||||
|
||||
### Define build parameters/env variables that we will re-ues in sourcing scripts.
|
||||
if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
|
||||
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_PATHS=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
|
||||
else
|
||||
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_PATHS=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
|
||||
fi
|
||||
BAZEL_ARM_FLAGS="--config=rpi3 --config=rpi3_opt --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||
BAZEL_ARM64_FLAGS="--config=rpi3-armv8 --config=rpi3-armv8_opt --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||
BAZEL_ANDROID_ARM_FLAGS="--config=android --config=android_arm --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||
BAZEL_ANDROID_ARM64_FLAGS="--config=android --config=android_arm64 --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||
BAZEL_CUDA_FLAGS="--config=cuda"
|
||||
if [ "${OS}" = "Linux" ]; then
|
||||
# constexpr usage in tensorflow's absl dep fails badly because of gcc-5
|
||||
# so let's skip that
|
||||
BAZEL_CUDA_FLAGS="${BAZEL_CUDA_FLAGS} --copt=-DNO_CONSTEXPR_FOR_YOU=1"
|
||||
fi
|
||||
BAZEL_IOS_ARM64_FLAGS="--config=ios_arm64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||
BAZEL_IOS_X86_64_FLAGS="--config=ios_x86_64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||
|
||||
if [ "${OS}" != "${CI_MSYS_VERSION}" ]; then
|
||||
BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl --copt=-fvisibility=hidden"
|
||||
BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl"
|
||||
fi
|
||||
|
||||
if [ "${OS}" = "Darwin" ]; then
|
||||
|
@ -189,11 +167,5 @@ fi
|
|||
|
||||
### Define build targets that we will re-ues in sourcing scripts.
|
||||
BUILD_TARGET_LIB_CPP_API="//tensorflow:tensorflow_cc"
|
||||
BUILD_TARGET_GRAPH_TRANSFORMS="//tensorflow/tools/graph_transforms:transform_graph"
|
||||
BUILD_TARGET_GRAPH_SUMMARIZE="//tensorflow/tools/graph_transforms:summarize_graph"
|
||||
BUILD_TARGET_GRAPH_BENCHMARK="//tensorflow/tools/benchmark:benchmark_model"
|
||||
#BUILD_TARGET_CONVERT_MMAP="//tensorflow/contrib/util:convert_graphdef_memmapped_format"
|
||||
BUILD_TARGET_TOCO="//tensorflow/lite/toco:toco"
|
||||
BUILD_TARGET_LITE_BENCHMARK="//tensorflow/lite/tools/benchmark:benchmark_model"
|
||||
BUILD_TARGET_LITE_LIB="//tensorflow/lite/c:libtensorflowlite_c.so"
|
||||
BUILD_TARGET_LITE_LIB="//tensorflow/lite:libtensorflowlite.so"
|
||||
BUILD_TARGET_LIBSTT="//native_client:libstt.so"
|
||||
|
|
|
@ -1,22 +1,9 @@
|
|||
# Description: Coqui STT native client library.
|
||||
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object", "tf_copts", "lrt_if_needed")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "lrt_if_needed")
|
||||
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
|
||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
|
||||
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite:build_def.bzl",
|
||||
"tflite_copts",
|
||||
"tflite_linkopts",
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "tflite",
|
||||
define_values = {
|
||||
"runtime": "tflite",
|
||||
},
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "rpi3",
|
||||
|
@ -85,7 +72,7 @@ LINUX_LINKOPTS = [
|
|||
"-Wl,-export-dynamic",
|
||||
]
|
||||
|
||||
tf_cc_shared_object(
|
||||
cc_binary(
|
||||
name = "libkenlm.so",
|
||||
srcs = glob([
|
||||
"kenlm/lm/*.hh",
|
||||
|
@ -107,8 +94,19 @@ tf_cc_shared_object(
|
|||
}),
|
||||
defines = ["KENLM_MAX_ORDER=6"],
|
||||
includes = ["kenlm"],
|
||||
framework_so = [],
|
||||
linkopts = [],
|
||||
linkshared = 1,
|
||||
linkopts = select({
|
||||
"//tensorflow:ios": [
|
||||
"-Wl,-install_name,@rpath/libkenlm.so",
|
||||
],
|
||||
"//tensorflow:macos": [
|
||||
"-Wl,-install_name,@rpath/libkenlm.so",
|
||||
],
|
||||
"//tensorflow:windows": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,-soname,libkenlm.so",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -132,6 +130,20 @@ cc_library(
|
|||
copts = ["-fexceptions"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name="tflite",
|
||||
hdrs = [
|
||||
"//tensorflow/lite:model.h",
|
||||
"//tensorflow/lite/kernels:register.h",
|
||||
"//tensorflow/lite/tools/evaluation:utils.h",
|
||||
],
|
||||
srcs = [
|
||||
"//tensorflow/lite:libtensorflowlite.so",
|
||||
],
|
||||
includes = ["tensorflow"],
|
||||
deps = ["//tensorflow/lite:libtensorflowlite.so"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "coqui_stt_bundle",
|
||||
srcs = [
|
||||
|
@ -142,17 +154,10 @@ cc_library(
|
|||
"modelstate.h",
|
||||
"workspace_status.cc",
|
||||
"workspace_status.h",
|
||||
] + select({
|
||||
"//native_client:tflite": [
|
||||
"tflitemodelstate.h",
|
||||
"tflitemodelstate.cc",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"tfmodelstate.h",
|
||||
"tfmodelstate.cc",
|
||||
],
|
||||
}) + DECODER_SOURCES,
|
||||
copts = tf_copts(allow_exceptions=True) + select({
|
||||
"tflitemodelstate.h",
|
||||
"tflitemodelstate.cc",
|
||||
] + DECODER_SOURCES,
|
||||
copts = select({
|
||||
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
|
||||
"//tensorflow:windows": ["/w"],
|
||||
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
|
||||
|
@ -161,9 +166,6 @@ cc_library(
|
|||
"-Wno-sign-compare",
|
||||
"-fvisibility=hidden",
|
||||
],
|
||||
}) + select({
|
||||
"//native_client:tflite": ["-DUSE_TFLITE"],
|
||||
"//conditions:default": ["-UUSE_TFLITE"],
|
||||
}),
|
||||
linkopts = lrt_if_needed() + select({
|
||||
"//tensorflow:macos": [],
|
||||
|
@ -174,64 +176,32 @@ cc_library(
|
|||
# Bazel is has too strong opinions about static linking, so it's
|
||||
# near impossible to get it to link a DLL against another DLL on Windows.
|
||||
# We simply force the linker option manually here as a hacky fix.
|
||||
"//tensorflow:windows": ["bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib"],
|
||||
"//tensorflow:windows": [
|
||||
"bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib",
|
||||
"bazel-out/x64_windows-opt/bin/tensorflow/lite/libtensorflowlite.so.if.lib",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + tflite_linkopts() + DECODER_LINKOPTS,
|
||||
}) + DECODER_LINKOPTS,
|
||||
includes = DECODER_INCLUDES,
|
||||
deps = select({
|
||||
"//native_client:tflite": [
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/tools/evaluation:utils",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//third_party/eigen3",
|
||||
#"//tensorflow/core:all_kernels",
|
||||
### => Trying to be more fine-grained
|
||||
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
|
||||
### CPU only build, libstt.so file size reduced by ~50%
|
||||
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
|
||||
"//tensorflow/core/kernels:bias_op", # BiasAdd
|
||||
"//tensorflow/core/kernels:cast_op", # Cast
|
||||
"//tensorflow/core/kernels:concat_op", # ConcatV2
|
||||
"//tensorflow/core/kernels:constant_op", # Const, Placeholder
|
||||
"//tensorflow/core/kernels:shape_ops", # ExpandDims, Shape
|
||||
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
|
||||
"//tensorflow/core/kernels:identity_op", # Identity
|
||||
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
|
||||
"//tensorflow/core/kernels:deepspeech_cwise_ops", # Less, Minimum, Mul
|
||||
"//tensorflow/core/kernels:matmul_op", # MatMul
|
||||
"//tensorflow/core/kernels:reduction_ops", # Max
|
||||
"//tensorflow/core/kernels:mfcc_op", # Mfcc
|
||||
"//tensorflow/core/kernels:no_op", # NoOp
|
||||
"//tensorflow/core/kernels:pack_op", # Pack
|
||||
"//tensorflow/core/kernels:sequence_ops", # Range
|
||||
"//tensorflow/core/kernels:relu_op", # Relu
|
||||
"//tensorflow/core/kernels:reshape_op", # Reshape
|
||||
"//tensorflow/core/kernels:softmax_op", # Softmax
|
||||
"//tensorflow/core/kernels:tile_ops", # Tile
|
||||
"//tensorflow/core/kernels:transpose_op", # Transpose
|
||||
"//tensorflow/core/kernels:rnn_ops", # BlockLSTM
|
||||
# And we also need the op libs for these ops used in the model:
|
||||
"//tensorflow/core:audio_ops_op_lib", # AudioSpectrogram, Mfcc
|
||||
"//tensorflow/core:rnn_ops_op_lib", # BlockLSTM
|
||||
"//tensorflow/core:math_ops_op_lib", # Cast, Less, Max, MatMul, Minimum, Range
|
||||
"//tensorflow/core:array_ops_op_lib", # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
|
||||
"//tensorflow/core:no_op_op_lib", # NoOp
|
||||
"//tensorflow/core:nn_ops_op_lib", # Relu, Softmax, BiasAdd
|
||||
# And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
|
||||
"//tensorflow/core:dataset_ops_op_lib", # UnwrapDatasetVariant, WrapDatasetVariant
|
||||
"//tensorflow/core:sendrecv_ops_op_lib", # _HostRecv, _HostSend, _Recv, _Send
|
||||
],
|
||||
}) + if_cuda([
|
||||
"//tensorflow/core:core",
|
||||
]) + [":kenlm"],
|
||||
deps = [":kenlm", ":tflite"],
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
cc_binary(
|
||||
name = "libstt.so",
|
||||
deps = [":coqui_stt_bundle"],
|
||||
linkshared = 1,
|
||||
linkopts = select({
|
||||
"//tensorflow:ios": [
|
||||
"-Wl,-install_name,@rpath/libstt.so",
|
||||
],
|
||||
"//tensorflow:macos": [
|
||||
"-Wl,-install_name,@rpath/libstt.so",
|
||||
],
|
||||
"//tensorflow:windows": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,-soname,libstt.so",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
ios_static_framework(
|
||||
|
|
|
@ -20,8 +20,8 @@ endif
|
|||
|
||||
STT_BIN := stt$(PLATFORM_EXE_SUFFIX)
|
||||
CFLAGS_STT := -std=c++11 -o $(STT_BIN)
|
||||
LINK_STT := -lstt -lkenlm
|
||||
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client
|
||||
LINK_STT := -lstt -lkenlm -ltensorflowlite
|
||||
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client -L${TFDIR}/bazel-bin/tensorflow/lite
|
||||
|
||||
ifeq ($(TARGET),host)
|
||||
TOOLCHAIN :=
|
||||
|
@ -61,7 +61,7 @@ TOOL_CC := cl.exe
|
|||
TOOL_CXX := cl.exe
|
||||
TOOL_LD := link.exe
|
||||
TOOL_LIBEXE := lib.exe
|
||||
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib")
|
||||
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib")
|
||||
LINK_PATH_STT :=
|
||||
CFLAGS_STT := -nologo -Fe$(STT_BIN)
|
||||
SOX_CFLAGS :=
|
||||
|
@ -185,7 +185,7 @@ define copy_missing_libs
|
|||
new_missing="$$( (for f in $$(otool -L $$lib 2>/dev/null | tail -n +2 | awk '{ print $$1 }' | grep -v '$$lib'); do ls -hal $$f; done;) 2>&1 | grep 'No such' | cut -d':' -f2 | xargs basename -a)"; \
|
||||
missing_libs="$$missing_libs $$new_missing"; \
|
||||
elif [ "$(OS)" = "${CI_MSYS_VERSION}" ]; then \
|
||||
missing_libs="libstt.so libkenlm.so"; \
|
||||
missing_libs="libstt.so libkenlm.so libtensorflowlite.so"; \
|
||||
else \
|
||||
missing_libs="$$missing_libs $$($(LDD) $$lib | grep 'not found' | awk '{ print $$1 }')"; \
|
||||
fi; \
|
||||
|
|
|
@ -27,12 +27,14 @@
|
|||
"libraries": [
|
||||
"../../../tensorflow/bazel-bin/native_client/libstt.so.if.lib",
|
||||
"../../../tensorflow/bazel-bin/native_client/libkenlm.so.if.lib",
|
||||
"../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib",
|
||||
],
|
||||
},
|
||||
{
|
||||
"libraries": [
|
||||
"../../../tensorflow/bazel-bin/native_client/libstt.so",
|
||||
"../../../tensorflow/bazel-bin/native_client/libkenlm.so",
|
||||
"../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so",
|
||||
],
|
||||
},
|
||||
],
|
||||
|
|
|
@ -14,13 +14,7 @@
|
|||
#include "modelstate.h"
|
||||
|
||||
#include "workspace_status.h"
|
||||
|
||||
#ifndef USE_TFLITE
|
||||
#include "tfmodelstate.h"
|
||||
#else
|
||||
#include "tflitemodelstate.h"
|
||||
#endif // USE_TFLITE
|
||||
|
||||
#include "ctcdecode/ctc_beam_search_decoder.h"
|
||||
|
||||
#ifdef __ANDROID__
|
||||
|
@ -282,13 +276,7 @@ STT_CreateModel(const char* aModelPath,
|
|||
return STT_ERR_NO_MODEL;
|
||||
}
|
||||
|
||||
std::unique_ptr<ModelState> model(
|
||||
#ifndef USE_TFLITE
|
||||
new TFModelState()
|
||||
#else
|
||||
new TFLiteModelState()
|
||||
#endif
|
||||
);
|
||||
std::unique_ptr<ModelState> model(new TFLiteModelState());
|
||||
|
||||
if (!model) {
|
||||
std::cerr << "Could not allocate model state." << std::endl;
|
||||
|
|
|
@ -1,263 +0,0 @@
|
|||
#include "tfmodelstate.h"
|
||||
|
||||
#include "workspace_status.h"
|
||||
|
||||
using namespace tensorflow;
|
||||
using std::vector;
|
||||
|
||||
TFModelState::TFModelState()
|
||||
: ModelState()
|
||||
, mmap_env_(nullptr)
|
||||
, session_(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
TFModelState::~TFModelState()
|
||||
{
|
||||
if (session_) {
|
||||
Status status = session_->Close();
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int
|
||||
TFModelState::init(const char* model_path)
|
||||
{
|
||||
int err = ModelState::init(model_path);
|
||||
if (err != STT_ERR_OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
Status status;
|
||||
SessionOptions options;
|
||||
|
||||
mmap_env_.reset(new MemmappedEnv(Env::Default()));
|
||||
|
||||
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
|
||||
if (!is_mmap) {
|
||||
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
|
||||
} else {
|
||||
status = mmap_env_->InitializeFromFile(model_path);
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return STT_ERR_FAIL_INIT_MMAP;
|
||||
}
|
||||
|
||||
options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(::OptimizerOptions::L0);
|
||||
options.env = mmap_env_.get();
|
||||
}
|
||||
|
||||
Session* session;
|
||||
status = NewSession(options, &session);
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return STT_ERR_FAIL_INIT_SESS;
|
||||
}
|
||||
session_.reset(session);
|
||||
|
||||
if (is_mmap) {
|
||||
status = ReadBinaryProto(mmap_env_.get(),
|
||||
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
|
||||
&graph_def_);
|
||||
} else {
|
||||
status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return STT_ERR_FAIL_READ_PROTOBUF;
|
||||
}
|
||||
|
||||
status = session_->Create(graph_def_);
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return STT_ERR_FAIL_CREATE_SESS;
|
||||
}
|
||||
|
||||
std::vector<tensorflow::Tensor> version_output;
|
||||
status = session_->Run({}, {
|
||||
"metadata_version"
|
||||
}, {}, &version_output);
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Unable to fetch graph version: " << status << std::endl;
|
||||
return STT_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
int graph_version = version_output[0].scalar<int>()();
|
||||
if (graph_version < ds_graph_version()) {
|
||||
std::cerr << "Specified model file version (" << graph_version << ") is "
|
||||
<< "incompatible with minimum version supported by this client ("
|
||||
<< ds_graph_version() << "). See "
|
||||
<< "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
|
||||
<< "for more information" << std::endl;
|
||||
return STT_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
std::vector<tensorflow::Tensor> metadata_outputs;
|
||||
status = session_->Run({}, {
|
||||
"metadata_sample_rate",
|
||||
"metadata_feature_win_len",
|
||||
"metadata_feature_win_step",
|
||||
"metadata_beam_width",
|
||||
"metadata_alphabet",
|
||||
}, {}, &metadata_outputs);
|
||||
if (!status.ok()) {
|
||||
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
||||
return STT_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
sample_rate_ = metadata_outputs[0].scalar<int>()();
|
||||
int win_len_ms = metadata_outputs[1].scalar<int>()();
|
||||
int win_step_ms = metadata_outputs[2].scalar<int>()();
|
||||
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
||||
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
||||
int beam_width = metadata_outputs[3].scalar<int>()();
|
||||
beam_width_ = (unsigned int)(beam_width);
|
||||
|
||||
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
|
||||
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
|
||||
if (err != 0) {
|
||||
return STT_ERR_INVALID_ALPHABET;
|
||||
}
|
||||
|
||||
assert(sample_rate_ > 0);
|
||||
assert(audio_win_len_ > 0);
|
||||
assert(audio_win_step_ > 0);
|
||||
assert(beam_width_ > 0);
|
||||
assert(alphabet_.GetSize() > 0);
|
||||
|
||||
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
||||
NodeDef node = graph_def_.node(i);
|
||||
if (node.name() == "input_node") {
|
||||
const auto& shape = node.attr().at("shape").shape();
|
||||
n_steps_ = shape.dim(1).size();
|
||||
n_context_ = (shape.dim(2).size()-1)/2;
|
||||
n_features_ = shape.dim(3).size();
|
||||
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
|
||||
} else if (node.name() == "previous_state_c") {
|
||||
const auto& shape = node.attr().at("shape").shape();
|
||||
state_size_ = shape.dim(1).size();
|
||||
} else if (node.name() == "logits_shape") {
|
||||
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
|
||||
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int final_dim_size = logits_shape.vec<int>()(2) - 1;
|
||||
if (final_dim_size != alphabet_.GetSize()) {
|
||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||
<< "has size " << alphabet_.GetSize()
|
||||
<< ", but model has " << final_dim_size
|
||||
<< " classes in its output. Make sure you're passing an alphabet "
|
||||
<< "file with the same size as the one used for training."
|
||||
<< std::endl;
|
||||
return STT_ERR_INVALID_ALPHABET;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (n_context_ == -1 || n_features_ == -1) {
|
||||
std::cerr << "Error: Could not infer input shape from model file. "
|
||||
<< "Make sure input_node is a 4D tensor with shape "
|
||||
<< "[batch_size=1, time, window_size, n_features]."
|
||||
<< std::endl;
|
||||
return STT_ERR_INVALID_SHAPE;
|
||||
}
|
||||
|
||||
return STT_ERR_OK;
|
||||
}
|
||||
|
||||
Tensor
|
||||
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
|
||||
{
|
||||
Tensor ret(DT_FLOAT, shape);
|
||||
auto ret_mapped = ret.flat<float>();
|
||||
int i;
|
||||
for (i = 0; i < vec.size(); ++i) {
|
||||
ret_mapped(i) = vec[i];
|
||||
}
|
||||
for (; i < shape.num_elements(); ++i) {
|
||||
ret_mapped(i) = 0.f;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
|
||||
{
|
||||
auto tensor_mapped = tensor.flat<float>();
|
||||
if (num_elements == -1) {
|
||||
num_elements = tensor.shape().num_elements();
|
||||
}
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
vec.push_back(tensor_mapped(i));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TFModelState::infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
const std::vector<float>& previous_state_c,
|
||||
const std::vector<float>& previous_state_h,
|
||||
vector<float>& logits_output,
|
||||
vector<float>& state_c_output,
|
||||
vector<float>& state_h_output)
|
||||
{
|
||||
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
|
||||
|
||||
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
||||
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
|
||||
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
||||
input_lengths.scalar<int>()() = n_frames;
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session_->Run(
|
||||
{
|
||||
{"input_node", input},
|
||||
{"input_lengths", input_lengths},
|
||||
{"previous_state_c", previous_state_c_t},
|
||||
{"previous_state_h", previous_state_h_t}
|
||||
},
|
||||
{"logits", "new_state_c", "new_state_h"},
|
||||
{},
|
||||
&outputs);
|
||||
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
|
||||
|
||||
state_c_output.clear();
|
||||
state_c_output.reserve(state_size_);
|
||||
copy_tensor_to_vector(outputs[1], state_c_output);
|
||||
|
||||
state_h_output.clear();
|
||||
state_h_output.reserve(state_size_);
|
||||
copy_tensor_to_vector(outputs[2], state_h_output);
|
||||
}
|
||||
|
||||
void
|
||||
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
||||
{
|
||||
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
|
||||
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// The feature computation graph is hardcoded to one audio length for now
|
||||
const int n_windows = 1;
|
||||
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
|
||||
copy_tensor_to_vector(outputs[0], mfcc_output);
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
#ifndef TFMODELSTATE_H
|
||||
#define TFMODELSTATE_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
|
||||
#include "modelstate.h"
|
||||
|
||||
struct TFModelState : public ModelState
|
||||
{
|
||||
std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
|
||||
std::unique_ptr<tensorflow::Session> session_;
|
||||
tensorflow::GraphDef graph_def_;
|
||||
|
||||
TFModelState();
|
||||
virtual ~TFModelState();
|
||||
|
||||
virtual int init(const char* model_path) override;
|
||||
|
||||
virtual void infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
const std::vector<float>& previous_state_c,
|
||||
const std::vector<float>& previous_state_h,
|
||||
std::vector<float>& logits_output,
|
||||
std::vector<float>& state_c_output,
|
||||
std::vector<float>& state_h_output) override;
|
||||
|
||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
||||
std::vector<float>& mfcc_output) override;
|
||||
};
|
||||
|
||||
#endif // TFMODELSTATE_H
|
Loading…
Reference in New Issue