Remove full TF backend
This commit is contained in:
parent
2020f1b15a
commit
42ebbf9120
|
@ -5,11 +5,8 @@ inputs:
|
||||||
description: "Target arch for loading script (host/armv7/aarch64)"
|
description: "Target arch for loading script (host/armv7/aarch64)"
|
||||||
required: false
|
required: false
|
||||||
default: "host"
|
default: "host"
|
||||||
flavor:
|
|
||||||
description: "Build flavor"
|
|
||||||
required: true
|
|
||||||
runs:
|
runs:
|
||||||
using: "composite"
|
using: "composite"
|
||||||
steps:
|
steps:
|
||||||
- run: ./ci_scripts/${{ inputs.arch }}-build.sh ${{ inputs.flavor }}
|
- run: ./ci_scripts/${{ inputs.arch }}-build.sh
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
name: "Python binding"
|
name: "Python binding"
|
||||||
description: "Binding a python binding"
|
description: "Binding a python binding"
|
||||||
inputs:
|
inputs:
|
||||||
build_flavor:
|
|
||||||
description: "Python package name"
|
|
||||||
required: true
|
|
||||||
numpy_build:
|
numpy_build:
|
||||||
description: "NumPy build dependecy"
|
description: "NumPy build dependecy"
|
||||||
required: true
|
required: true
|
||||||
|
@ -46,9 +43,6 @@ runs:
|
||||||
set -xe
|
set -xe
|
||||||
|
|
||||||
PROJECT_NAME="stt"
|
PROJECT_NAME="stt"
|
||||||
if [ "${{ inputs.build_flavor }}" = "tflite" ]; then
|
|
||||||
PROJECT_NAME="stt-tflite"
|
|
||||||
fi
|
|
||||||
|
|
||||||
OS=$(uname)
|
OS=$(uname)
|
||||||
if [ "${OS}" = "Linux" ]; then
|
if [ "${OS}" = "Linux" ]; then
|
||||||
|
|
|
@ -4,9 +4,6 @@ inputs:
|
||||||
runtime:
|
runtime:
|
||||||
description: "Runtime to use for running test"
|
description: "Runtime to use for running test"
|
||||||
required: true
|
required: true
|
||||||
build-flavor:
|
|
||||||
description: "Running against TF or TFLite"
|
|
||||||
required: true
|
|
||||||
model-kind:
|
model-kind:
|
||||||
description: "Running against CI baked or production model"
|
description: "Running against CI baked or production model"
|
||||||
required: true
|
required: true
|
||||||
|
@ -22,10 +19,7 @@ runs:
|
||||||
- run: |
|
- run: |
|
||||||
set -xe
|
set -xe
|
||||||
|
|
||||||
build=""
|
|
||||||
if [ "${{ inputs.build-flavor }}" = "tflite" ]; then
|
|
||||||
build="_tflite"
|
build="_tflite"
|
||||||
fi
|
|
||||||
|
|
||||||
model_kind=""
|
model_kind=""
|
||||||
if [ "${{ inputs.model-kind }}" = "prod" ]; then
|
if [ "${{ inputs.model-kind }}" = "prod" ]; then
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -55,23 +55,6 @@ maybe_install_xldd()
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# Checks whether we run a patched version of bazel.
|
|
||||||
# Patching is required to dump computeKey() parameters to .ckd files
|
|
||||||
# See bazel.patch
|
|
||||||
# Return 0 (success exit code) on patched version, 1 on release version
|
|
||||||
is_patched_bazel()
|
|
||||||
{
|
|
||||||
bazel_version=$(bazel version | grep 'Build label:' | cut -d':' -f2)
|
|
||||||
|
|
||||||
bazel shutdown
|
|
||||||
|
|
||||||
if [ -z "${bazel_version}" ]; then
|
|
||||||
return 0;
|
|
||||||
else
|
|
||||||
return 1;
|
|
||||||
fi;
|
|
||||||
}
|
|
||||||
|
|
||||||
verify_bazel_rebuild()
|
verify_bazel_rebuild()
|
||||||
{
|
{
|
||||||
bazel_explain_file="$1"
|
bazel_explain_file="$1"
|
||||||
|
|
|
@ -9,21 +9,15 @@ do_bazel_build()
|
||||||
cd ${DS_TFDIR}
|
cd ${DS_TFDIR}
|
||||||
eval "export ${BAZEL_ENV_FLAGS}"
|
eval "export ${BAZEL_ENV_FLAGS}"
|
||||||
|
|
||||||
if [ "${_opt_or_dbg}" = "opt" ]; then
|
|
||||||
if is_patched_bazel; then
|
|
||||||
find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-tf.tar -T -
|
|
||||||
fi;
|
|
||||||
fi;
|
|
||||||
|
|
||||||
bazel ${BAZEL_OUTPUT_USER_ROOT} build \
|
bazel ${BAZEL_OUTPUT_USER_ROOT} build \
|
||||||
-s --explain bazel_monolithic.log --verbose_explanations --experimental_strict_action_env --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
|
-s --explain bazel_explain.log --verbose_explanations \
|
||||||
|
--experimental_strict_action_env \
|
||||||
|
--workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" \
|
||||||
|
-c ${_opt_or_dbg} ${BAZEL_BUILD_FLAGS} ${BAZEL_TARGETS}
|
||||||
|
|
||||||
if [ "${_opt_or_dbg}" = "opt" ]; then
|
if [ "${_opt_or_dbg}" = "opt" ]; then
|
||||||
if is_patched_bazel; then
|
verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_explain.log"
|
||||||
find ${DS_ROOT_TASK}/tensorflow/bazel-out/ -iname "*.ckd" | tar -cf ${DS_ROOT_TASK}/bazel-ckd-ds.tar -T -
|
fi
|
||||||
fi;
|
|
||||||
verify_bazel_rebuild "${DS_ROOT_TASK}/tensorflow/bazel_monolithic.log"
|
|
||||||
fi;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
shutdown_bazel()
|
shutdown_bazel()
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
model_source=${STT_PROD_MODEL}
|
|
||||||
model_name=$(basename "${model_source}")
|
|
||||||
|
|
||||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
|
||||||
model_name_mmap=$(basename "${model_source_mmap}")
|
|
||||||
|
|
||||||
download_model_prod
|
|
||||||
|
|
||||||
download_material
|
|
||||||
|
|
||||||
export PATH=${CI_TMP_DIR}/ds/:$PATH
|
|
||||||
|
|
||||||
check_versions
|
|
||||||
|
|
||||||
run_prod_inference_tests "${bitrate}"
|
|
|
@ -1,24 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
download_data
|
|
||||||
|
|
||||||
export PATH=${CI_TMP_DIR}/ds/:$PATH
|
|
||||||
|
|
||||||
check_versions
|
|
||||||
|
|
||||||
run_all_inference_tests
|
|
||||||
|
|
||||||
run_multi_inference_tests
|
|
||||||
|
|
||||||
run_cpp_only_inference_tests
|
|
||||||
|
|
||||||
run_hotword_tests
|
|
|
@ -1,20 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
download_material "${CI_TMP_DIR}/ds"
|
|
||||||
|
|
||||||
export PATH=${CI_TMP_DIR}/ds/:$PATH
|
|
||||||
|
|
||||||
check_versions
|
|
||||||
|
|
||||||
ensure_cuda_usage "$2"
|
|
||||||
|
|
||||||
run_basic_inference_tests
|
|
|
@ -1,48 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
model_source=${STT_PROD_MODEL}
|
|
||||||
model_name=$(basename "${model_source}")
|
|
||||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
|
||||||
model_name_mmap=$(basename "${model_source_mmap}")
|
|
||||||
|
|
||||||
download_model_prod
|
|
||||||
|
|
||||||
download_data
|
|
||||||
|
|
||||||
node --version
|
|
||||||
npm --version
|
|
||||||
|
|
||||||
symlink_electron
|
|
||||||
|
|
||||||
export_node_bin_path
|
|
||||||
|
|
||||||
which electron
|
|
||||||
which node
|
|
||||||
|
|
||||||
if [ "${OS}" = "Linux" ]; then
|
|
||||||
export DISPLAY=':99.0'
|
|
||||||
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
|
|
||||||
xvfb_process=$!
|
|
||||||
fi
|
|
||||||
|
|
||||||
node --version
|
|
||||||
|
|
||||||
stt --version
|
|
||||||
|
|
||||||
check_runtime_electronjs
|
|
||||||
|
|
||||||
run_electronjs_prod_inference_tests "${bitrate}"
|
|
||||||
|
|
||||||
if [ "${OS}" = "Linux" ]; then
|
|
||||||
sleep 1
|
|
||||||
sudo kill -9 ${xvfb_process} || true
|
|
||||||
fi
|
|
|
@ -1,41 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
download_data
|
|
||||||
|
|
||||||
node --version
|
|
||||||
npm --version
|
|
||||||
|
|
||||||
symlink_electron
|
|
||||||
|
|
||||||
export_node_bin_path
|
|
||||||
|
|
||||||
which electron
|
|
||||||
which node
|
|
||||||
|
|
||||||
if [ "${OS}" = "Linux" ]; then
|
|
||||||
export DISPLAY=':99.0'
|
|
||||||
sudo Xvfb :99 -screen 0 1024x768x24 > /dev/null 2>&1 &
|
|
||||||
xvfb_process=$!
|
|
||||||
fi
|
|
||||||
|
|
||||||
node --version
|
|
||||||
|
|
||||||
stt --version
|
|
||||||
|
|
||||||
check_runtime_electronjs
|
|
||||||
|
|
||||||
run_electronjs_inference_tests
|
|
||||||
|
|
||||||
if [ "${OS}" = "Linux" ]; then
|
|
||||||
sleep 1
|
|
||||||
sudo kill -9 ${xvfb_process} || true
|
|
||||||
fi
|
|
|
@ -2,8 +2,6 @@
|
||||||
|
|
||||||
set -xe
|
set -xe
|
||||||
|
|
||||||
runtime=$1
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
source $(dirname "$0")/all-vars.sh
|
||||||
source $(dirname "$0")/all-utils.sh
|
source $(dirname "$0")/all-utils.sh
|
||||||
source $(dirname "$0")/build-utils.sh
|
source $(dirname "$0")/build-utils.sh
|
||||||
|
@ -15,10 +13,7 @@ BAZEL_TARGETS="
|
||||||
//native_client:generate_scorer_package
|
//native_client:generate_scorer_package
|
||||||
"
|
"
|
||||||
|
|
||||||
if [ "${runtime}" = "tflite" ]; then
|
BAZEL_BUILD_FLAGS="${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
|
||||||
BAZEL_BUILD_TFLITE="--define=runtime=tflite"
|
|
||||||
fi;
|
|
||||||
BAZEL_BUILD_FLAGS="${BAZEL_BUILD_TFLITE} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS}"
|
|
||||||
|
|
||||||
BAZEL_ENV_FLAGS="TF_NEED_CUDA=0"
|
BAZEL_ENV_FLAGS="TF_NEED_CUDA=0"
|
||||||
SYSTEM_TARGET=host
|
SYSTEM_TARGET=host
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
model_source=${STT_PROD_MODEL}
|
|
||||||
model_name=$(basename "${model_source}")
|
|
||||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
|
||||||
model_name_mmap=$(basename "${model_source_mmap}")
|
|
||||||
|
|
||||||
download_model_prod
|
|
||||||
|
|
||||||
download_data
|
|
||||||
|
|
||||||
node --version
|
|
||||||
npm --version
|
|
||||||
|
|
||||||
export_node_bin_path
|
|
||||||
|
|
||||||
check_runtime_nodejs
|
|
||||||
|
|
||||||
run_prod_inference_tests "${bitrate}"
|
|
||||||
|
|
||||||
run_js_streaming_prod_inference_tests "${bitrate}"
|
|
|
@ -1,25 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
download_data
|
|
||||||
|
|
||||||
node --version
|
|
||||||
npm --version
|
|
||||||
|
|
||||||
export_node_bin_path
|
|
||||||
|
|
||||||
check_runtime_nodejs
|
|
||||||
|
|
||||||
run_all_inference_tests
|
|
||||||
|
|
||||||
run_js_streaming_inference_tests
|
|
||||||
|
|
||||||
run_hotword_tests
|
|
|
@ -30,10 +30,15 @@ package_native_client()
|
||||||
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so.if.lib"
|
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so.if.lib"
|
||||||
fi;
|
fi;
|
||||||
|
|
||||||
|
if [ -f "${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib" ]; then
|
||||||
|
win_lib="$win_lib -C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so.if.lib"
|
||||||
|
fi;
|
||||||
|
|
||||||
${TAR} --verbose -cf - \
|
${TAR} --verbose -cf - \
|
||||||
--transform='flags=r;s|README.coqui|KenLM_License_Info.txt|' \
|
--transform='flags=r;s|README.coqui|KenLM_License_Info.txt|' \
|
||||||
-C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so \
|
-C ${tensorflow_dir}/bazel-bin/native_client/ libstt.so \
|
||||||
-C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so \
|
-C ${tensorflow_dir}/bazel-bin/native_client/ libkenlm.so \
|
||||||
|
-C ${tensorflow_dir}/bazel-bin/tensorflow/lite/ libtensorflowlite.so \
|
||||||
${win_lib} \
|
${win_lib} \
|
||||||
-C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
|
-C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
|
||||||
-C ${stt_dir}/ LICENSE \
|
-C ${stt_dir}/ LICENSE \
|
||||||
|
@ -94,5 +99,8 @@ package_libstt_as_zip()
|
||||||
echo "Please specify artifact name."
|
echo "Please specify artifact name."
|
||||||
fi;
|
fi;
|
||||||
|
|
||||||
${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" ${tensorflow_dir}/bazel-bin/native_client/libstt.so
|
${ZIP} -r9 --junk-paths "${artifacts_dir}/${artifact_name}" \
|
||||||
|
${tensorflow_dir}/bazel-bin/native_client/libstt.so \
|
||||||
|
${tensorflow_dir}/bazel-bin/native_client/libkenlm.so \
|
||||||
|
${tensorflow_dir}/bazel-bin/tensorflow/lite/libtensorflowlite.so
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,29 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
model_source=${STT_PROD_MODEL}
|
|
||||||
model_name=$(basename "${model_source}")
|
|
||||||
|
|
||||||
model_source_mmap=${STT_PROD_MODEL_MMAP}
|
|
||||||
model_name_mmap=$(basename "${model_source_mmap}")
|
|
||||||
|
|
||||||
download_model_prod
|
|
||||||
|
|
||||||
download_material
|
|
||||||
|
|
||||||
export_py_bin_path
|
|
||||||
|
|
||||||
which stt
|
|
||||||
stt --version
|
|
||||||
|
|
||||||
run_prod_inference_tests "${bitrate}"
|
|
||||||
|
|
||||||
run_prod_concurrent_stream_tests "${bitrate}"
|
|
|
@ -1,21 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
|
|
||||||
source $(dirname "$0")/all-vars.sh
|
|
||||||
source $(dirname "$0")/all-utils.sh
|
|
||||||
source $(dirname "$0")/asserts.sh
|
|
||||||
|
|
||||||
bitrate=$1
|
|
||||||
set_ldc_sample_filename "${bitrate}"
|
|
||||||
|
|
||||||
download_data
|
|
||||||
|
|
||||||
export_py_bin_path
|
|
||||||
|
|
||||||
which stt
|
|
||||||
stt --version
|
|
||||||
|
|
||||||
run_all_inference_tests
|
|
||||||
|
|
||||||
run_hotword_tests
|
|
|
@ -6,7 +6,7 @@ set -o pipefail
|
||||||
source $(dirname $0)/tf-vars.sh
|
source $(dirname $0)/tf-vars.sh
|
||||||
|
|
||||||
pushd ${DS_ROOT_TASK}/tensorflow/
|
pushd ${DS_ROOT_TASK}/tensorflow/
|
||||||
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s --explain bazel_monolithic_tf.log --verbose_explanations --experimental_strict_action_env --config=monolithic"
|
BAZEL_BUILD="bazel ${BAZEL_OUTPUT_USER_ROOT} build -s"
|
||||||
|
|
||||||
# Start a bazel process to ensure reliability on Windows and avoid:
|
# Start a bazel process to ensure reliability on Windows and avoid:
|
||||||
# FATAL: corrupt installation: file 'c:\builds\tc-workdir\.bazel_cache/install/6b1660721930e9d5f231f7d2a626209b/_embedded_binaries/build-runfiles.exe' missing.
|
# FATAL: corrupt installation: file 'c:\builds\tc-workdir\.bazel_cache/install/6b1660721930e9d5f231f7d2a626209b/_embedded_binaries/build-runfiles.exe' missing.
|
||||||
|
@ -23,13 +23,10 @@ pushd ${DS_ROOT_TASK}/tensorflow/
|
||||||
|
|
||||||
case "$1" in
|
case "$1" in
|
||||||
"--windows-cpu")
|
"--windows-cpu")
|
||||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIBSTT} ${BUILD_TARGET_LITE_LIB} --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh"
|
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
||||||
;;
|
;;
|
||||||
"--linux-cpu"|"--darwin-cpu")
|
"--linux-cpu"|"--darwin-cpu")
|
||||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LIB_CPP_API} ${BUILD_TARGET_LITE_LIB}
|
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_OPT_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
||||||
;;
|
|
||||||
"--linux-cuda"|"--windows-cuda")
|
|
||||||
eval "export ${TF_CUDA_FLAGS}" && (echo "" | TF_NEED_CUDA=1 ./configure) && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_CUDA_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BAZEL_OPT_FLAGS} ${BUILD_TARGET_LIB_CPP_API}
|
|
||||||
;;
|
;;
|
||||||
"--linux-armv7")
|
"--linux-armv7")
|
||||||
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
echo "" | TF_NEED_CUDA=0 ./configure && ${BAZEL_BUILD} ${OPT_OR_DBG} ${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS} ${BUILD_TARGET_LITE_LIB}
|
||||||
|
|
|
@ -6,26 +6,17 @@ source $(dirname $0)/tf-vars.sh
|
||||||
|
|
||||||
mkdir -p ${CI_ARTIFACTS_DIR} || true
|
mkdir -p ${CI_ARTIFACTS_DIR} || true
|
||||||
|
|
||||||
cp ${DS_ROOT_TASK}/tensorflow/bazel_*.log ${CI_ARTIFACTS_DIR} || true
|
|
||||||
|
|
||||||
OUTPUT_ROOT="${DS_ROOT_TASK}/tensorflow/bazel-bin"
|
OUTPUT_ROOT="${DS_ROOT_TASK}/tensorflow/bazel-bin"
|
||||||
|
|
||||||
for output_bin in \
|
for output_bin in \
|
||||||
tensorflow/lite/experimental/c/libtensorflowlite_c.so \
|
tensorflow/lite/libtensorflow.so \
|
||||||
tensorflow/tools/graph_transforms/transform_graph \
|
tensorflow/lite/libtensorflow.so.if.lib \
|
||||||
tensorflow/tools/graph_transforms/summarize_graph \
|
;
|
||||||
tensorflow/tools/benchmark/benchmark_model \
|
|
||||||
tensorflow/contrib/util/convert_graphdef_memmapped_format \
|
|
||||||
tensorflow/lite/toco/toco;
|
|
||||||
do
|
do
|
||||||
if [ -f "${OUTPUT_ROOT}/${output_bin}" ]; then
|
if [ -f "${OUTPUT_ROOT}/${output_bin}" ]; then
|
||||||
cp ${OUTPUT_ROOT}/${output_bin} ${CI_ARTIFACTS_DIR}/
|
cp ${OUTPUT_ROOT}/${output_bin} ${CI_ARTIFACTS_DIR}/
|
||||||
fi;
|
fi;
|
||||||
done;
|
done
|
||||||
|
|
||||||
if [ -f "${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model" ]; then
|
|
||||||
cp ${OUTPUT_ROOT}/tensorflow/lite/tools/benchmark/benchmark_model ${CI_ARTIFACTS_DIR}/lite_benchmark_model
|
|
||||||
fi
|
|
||||||
|
|
||||||
# It seems that bsdtar and gnutar are behaving a bit differently on the way
|
# It seems that bsdtar and gnutar are behaving a bit differently on the way
|
||||||
# they deal with --exclude="./public/*" ; this caused ./STT/tensorflow/core/public/
|
# they deal with --exclude="./public/*" ; this caused ./STT/tensorflow/core/public/
|
||||||
|
|
|
@ -5,12 +5,7 @@ set -ex
|
||||||
source $(dirname $0)/tf-vars.sh
|
source $(dirname $0)/tf-vars.sh
|
||||||
|
|
||||||
install_android=
|
install_android=
|
||||||
install_cuda=
|
|
||||||
case "$1" in
|
case "$1" in
|
||||||
"--linux-cuda"|"--windows-cuda")
|
|
||||||
install_cuda=yes
|
|
||||||
;;
|
|
||||||
|
|
||||||
"--android-armv7"|"--android-arm64")
|
"--android-armv7"|"--android-arm64")
|
||||||
install_android=yes
|
install_android=yes
|
||||||
;;
|
;;
|
||||||
|
@ -29,11 +24,6 @@ download()
|
||||||
mkdir -p ${DS_ROOT_TASK}/dls || true
|
mkdir -p ${DS_ROOT_TASK}/dls || true
|
||||||
download $BAZEL_URL $BAZEL_SHA256
|
download $BAZEL_URL $BAZEL_SHA256
|
||||||
|
|
||||||
if [ ! -z "${install_cuda}" ]; then
|
|
||||||
download $CUDA_URL $CUDA_SHA256
|
|
||||||
download $CUDNN_URL $CUDNN_SHA256
|
|
||||||
fi;
|
|
||||||
|
|
||||||
if [ ! -z "${install_android}" ]; then
|
if [ ! -z "${install_android}" ]; then
|
||||||
download $ANDROID_NDK_URL $ANDROID_NDK_SHA256
|
download $ANDROID_NDK_URL $ANDROID_NDK_SHA256
|
||||||
download $ANDROID_SDK_URL $ANDROID_SDK_SHA256
|
download $ANDROID_SDK_URL $ANDROID_SDK_SHA256
|
||||||
|
@ -63,30 +53,6 @@ bazel version
|
||||||
|
|
||||||
bazel shutdown
|
bazel shutdown
|
||||||
|
|
||||||
if [ ! -z "${install_cuda}" ]; then
|
|
||||||
# Install CUDA and CuDNN
|
|
||||||
mkdir -p ${DS_ROOT_TASK}/STT/CUDA/ || true
|
|
||||||
pushd ${DS_ROOT_TASK}
|
|
||||||
CUDA_FILE=`basename ${CUDA_URL}`
|
|
||||||
PERL5LIB=. sh ${DS_ROOT_TASK}/dls/${CUDA_FILE} --silent --override --toolkit --toolkitpath=${DS_ROOT_TASK}/STT/CUDA/ --defaultroot=${DS_ROOT_TASK}/STT/CUDA/
|
|
||||||
|
|
||||||
CUDNN_FILE=`basename ${CUDNN_URL}`
|
|
||||||
tar xvf ${DS_ROOT_TASK}/dls/${CUDNN_FILE} --strip-components=1 -C ${DS_ROOT_TASK}/STT/CUDA/
|
|
||||||
popd
|
|
||||||
|
|
||||||
LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
|
|
||||||
export LD_LIBRARY_PATH
|
|
||||||
|
|
||||||
# We might lack libcuda.so.1 symlink, let's fix as upstream does:
|
|
||||||
# https://github.com/tensorflow/tensorflow/pull/13811/files?diff=split#diff-2352449eb75e66016e97a591d3f0f43dR96
|
|
||||||
if [ ! -h "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1" ]; then
|
|
||||||
ln -s "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so" "${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/libcuda.so.1"
|
|
||||||
fi;
|
|
||||||
|
|
||||||
else
|
|
||||||
echo "No CUDA/CuDNN to install"
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ ! -z "${install_android}" ]; then
|
if [ ! -z "${install_android}" ]; then
|
||||||
mkdir -p ${DS_ROOT_TASK}/STT/Android/SDK || true
|
mkdir -p ${DS_ROOT_TASK}/STT/Android/SDK || true
|
||||||
ANDROID_NDK_FILE=`basename ${ANDROID_NDK_URL}`
|
ANDROID_NDK_FILE=`basename ${ANDROID_NDK_URL}`
|
||||||
|
|
|
@ -9,13 +9,6 @@ if [ "${OS}" = "Linux" ]; then
|
||||||
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh
|
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh
|
||||||
BAZEL_SHA256=7ba815cbac712d061fe728fef958651512ff394b2708e89f79586ec93d1185ed
|
BAZEL_SHA256=7ba815cbac712d061fe728fef958651512ff394b2708e89f79586ec93d1185ed
|
||||||
|
|
||||||
CUDA_URL=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run
|
|
||||||
CUDA_SHA256=e7c22dc21278eb1b82f34a60ad7640b41ad3943d929bebda3008b72536855d31
|
|
||||||
|
|
||||||
# From https://gitlab.com/nvidia/cuda/blob/centos7/10.1/devel/cudnn7/Dockerfile
|
|
||||||
CUDNN_URL=http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.1-linux-x64-v7.6.0.64.tgz
|
|
||||||
CUDNN_SHA256=e956c6f9222fcb867a10449cfc76dee5cfd7c7531021d95fe9586d7e043b57d7
|
|
||||||
|
|
||||||
ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r18b-linux-x86_64.zip
|
ANDROID_NDK_URL=https://dl.google.com/android/repository/android-ndk-r18b-linux-x86_64.zip
|
||||||
ANDROID_NDK_SHA256=4f61cbe4bbf6406aa5ef2ae871def78010eed6271af72de83f8bd0b07a9fd3fd
|
ANDROID_NDK_SHA256=4f61cbe4bbf6406aa5ef2ae871def78010eed6271af72de83f8bd0b07a9fd3fd
|
||||||
|
|
||||||
|
@ -48,8 +41,6 @@ elif [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
|
||||||
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-windows-x86_64.exe
|
BAZEL_URL=https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-windows-x86_64.exe
|
||||||
BAZEL_SHA256=776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9
|
BAZEL_SHA256=776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9
|
||||||
|
|
||||||
CUDA_INSTALL_DIRECTORY=$(cygpath 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1')
|
|
||||||
|
|
||||||
TAR=/usr/bin/tar.exe
|
TAR=/usr/bin/tar.exe
|
||||||
elif [ "${OS}" = "Darwin" ]; then
|
elif [ "${OS}" = "Darwin" ]; then
|
||||||
if [ -z "${CI_TASK_DIR}" -o -z "${CI_ARTIFACTS_DIR}" ]; then
|
if [ -z "${CI_TASK_DIR}" -o -z "${CI_ARTIFACTS_DIR}" ]; then
|
||||||
|
@ -89,7 +80,6 @@ fi;
|
||||||
export PATH
|
export PATH
|
||||||
|
|
||||||
if [ "${OS}" = "Linux" ]; then
|
if [ "${OS}" = "Linux" ]; then
|
||||||
export LD_LIBRARY_PATH=${DS_ROOT_TASK}/STT/CUDA/lib64/:${DS_ROOT_TASK}/STT/CUDA/lib64/stubs/:$LD_LIBRARY_PATH
|
|
||||||
export ANDROID_SDK_HOME=${DS_ROOT_TASK}/STT/Android/SDK/
|
export ANDROID_SDK_HOME=${DS_ROOT_TASK}/STT/Android/SDK/
|
||||||
export ANDROID_NDK_HOME=${DS_ROOT_TASK}/STT/Android/android-ndk-r18b/
|
export ANDROID_NDK_HOME=${DS_ROOT_TASK}/STT/Android/android-ndk-r18b/
|
||||||
fi;
|
fi;
|
||||||
|
@ -160,27 +150,15 @@ export BAZEL_OUTPUT_USER_ROOT
|
||||||
|
|
||||||
NVCC_COMPUTE="3.5"
|
NVCC_COMPUTE="3.5"
|
||||||
|
|
||||||
### Define build parameters/env variables that we will re-ues in sourcing scripts.
|
|
||||||
if [ "${OS}" = "${CI_MSYS_VERSION}" ]; then
|
|
||||||
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_PATHS=\"${CUDA_INSTALL_DIRECTORY}\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
|
|
||||||
else
|
|
||||||
TF_CUDA_FLAGS="TF_CUDA_CLANG=0 TF_CUDA_VERSION=10.1 TF_CUDNN_VERSION=7.6.0 CUDNN_INSTALL_PATH=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_PATHS=\"${DS_ROOT_TASK}/STT/CUDA\" TF_CUDA_COMPUTE_CAPABILITIES=\"${NVCC_COMPUTE}\""
|
|
||||||
fi
|
|
||||||
BAZEL_ARM_FLAGS="--config=rpi3 --config=rpi3_opt --copt=-DTFLITE_WITH_RUY_GEMV"
|
BAZEL_ARM_FLAGS="--config=rpi3 --config=rpi3_opt --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||||
BAZEL_ARM64_FLAGS="--config=rpi3-armv8 --config=rpi3-armv8_opt --copt=-DTFLITE_WITH_RUY_GEMV"
|
BAZEL_ARM64_FLAGS="--config=rpi3-armv8 --config=rpi3-armv8_opt --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||||
BAZEL_ANDROID_ARM_FLAGS="--config=android --config=android_arm --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
|
BAZEL_ANDROID_ARM_FLAGS="--config=android --config=android_arm --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||||
BAZEL_ANDROID_ARM64_FLAGS="--config=android --config=android_arm64 --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
|
BAZEL_ANDROID_ARM64_FLAGS="--config=android --config=android_arm64 --action_env ANDROID_NDK_API_LEVEL=21 --cxxopt=-std=c++14 --copt=-D_GLIBCXX_USE_C99 --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||||
BAZEL_CUDA_FLAGS="--config=cuda"
|
|
||||||
if [ "${OS}" = "Linux" ]; then
|
|
||||||
# constexpr usage in tensorflow's absl dep fails badly because of gcc-5
|
|
||||||
# so let's skip that
|
|
||||||
BAZEL_CUDA_FLAGS="${BAZEL_CUDA_FLAGS} --copt=-DNO_CONSTEXPR_FOR_YOU=1"
|
|
||||||
fi
|
|
||||||
BAZEL_IOS_ARM64_FLAGS="--config=ios_arm64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
|
BAZEL_IOS_ARM64_FLAGS="--config=ios_arm64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||||
BAZEL_IOS_X86_64_FLAGS="--config=ios_x86_64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
|
BAZEL_IOS_X86_64_FLAGS="--config=ios_x86_64 --define=runtime=tflite --copt=-DTFLITE_WITH_RUY_GEMV"
|
||||||
|
|
||||||
if [ "${OS}" != "${CI_MSYS_VERSION}" ]; then
|
if [ "${OS}" != "${CI_MSYS_VERSION}" ]; then
|
||||||
BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl --copt=-fvisibility=hidden"
|
BAZEL_EXTRA_FLAGS="--config=noaws --config=nogcp --config=nohdfs --config=nonccl"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "${OS}" = "Darwin" ]; then
|
if [ "${OS}" = "Darwin" ]; then
|
||||||
|
@ -189,11 +167,5 @@ fi
|
||||||
|
|
||||||
### Define build targets that we will re-ues in sourcing scripts.
|
### Define build targets that we will re-ues in sourcing scripts.
|
||||||
BUILD_TARGET_LIB_CPP_API="//tensorflow:tensorflow_cc"
|
BUILD_TARGET_LIB_CPP_API="//tensorflow:tensorflow_cc"
|
||||||
BUILD_TARGET_GRAPH_TRANSFORMS="//tensorflow/tools/graph_transforms:transform_graph"
|
BUILD_TARGET_LITE_LIB="//tensorflow/lite:libtensorflowlite.so"
|
||||||
BUILD_TARGET_GRAPH_SUMMARIZE="//tensorflow/tools/graph_transforms:summarize_graph"
|
|
||||||
BUILD_TARGET_GRAPH_BENCHMARK="//tensorflow/tools/benchmark:benchmark_model"
|
|
||||||
#BUILD_TARGET_CONVERT_MMAP="//tensorflow/contrib/util:convert_graphdef_memmapped_format"
|
|
||||||
BUILD_TARGET_TOCO="//tensorflow/lite/toco:toco"
|
|
||||||
BUILD_TARGET_LITE_BENCHMARK="//tensorflow/lite/tools/benchmark:benchmark_model"
|
|
||||||
BUILD_TARGET_LITE_LIB="//tensorflow/lite/c:libtensorflowlite_c.so"
|
|
||||||
BUILD_TARGET_LIBSTT="//native_client:libstt.so"
|
BUILD_TARGET_LIBSTT="//native_client:libstt.so"
|
||||||
|
|
|
@ -1,22 +1,9 @@
|
||||||
# Description: Coqui STT native client library.
|
# Description: Coqui STT native client library.
|
||||||
|
|
||||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object", "tf_copts", "lrt_if_needed")
|
load("@org_tensorflow//tensorflow:tensorflow.bzl", "lrt_if_needed")
|
||||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
|
||||||
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
|
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
|
||||||
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
|
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework")
|
||||||
|
|
||||||
load(
|
|
||||||
"@org_tensorflow//tensorflow/lite:build_def.bzl",
|
|
||||||
"tflite_copts",
|
|
||||||
"tflite_linkopts",
|
|
||||||
)
|
|
||||||
|
|
||||||
config_setting(
|
|
||||||
name = "tflite",
|
|
||||||
define_values = {
|
|
||||||
"runtime": "tflite",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "rpi3",
|
name = "rpi3",
|
||||||
|
@ -85,7 +72,7 @@ LINUX_LINKOPTS = [
|
||||||
"-Wl,-export-dynamic",
|
"-Wl,-export-dynamic",
|
||||||
]
|
]
|
||||||
|
|
||||||
tf_cc_shared_object(
|
cc_binary(
|
||||||
name = "libkenlm.so",
|
name = "libkenlm.so",
|
||||||
srcs = glob([
|
srcs = glob([
|
||||||
"kenlm/lm/*.hh",
|
"kenlm/lm/*.hh",
|
||||||
|
@ -107,8 +94,19 @@ tf_cc_shared_object(
|
||||||
}),
|
}),
|
||||||
defines = ["KENLM_MAX_ORDER=6"],
|
defines = ["KENLM_MAX_ORDER=6"],
|
||||||
includes = ["kenlm"],
|
includes = ["kenlm"],
|
||||||
framework_so = [],
|
linkshared = 1,
|
||||||
linkopts = [],
|
linkopts = select({
|
||||||
|
"//tensorflow:ios": [
|
||||||
|
"-Wl,-install_name,@rpath/libkenlm.so",
|
||||||
|
],
|
||||||
|
"//tensorflow:macos": [
|
||||||
|
"-Wl,-install_name,@rpath/libkenlm.so",
|
||||||
|
],
|
||||||
|
"//tensorflow:windows": [],
|
||||||
|
"//conditions:default": [
|
||||||
|
"-Wl,-soname,libkenlm.so",
|
||||||
|
],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -132,6 +130,20 @@ cc_library(
|
||||||
copts = ["-fexceptions"],
|
copts = ["-fexceptions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name="tflite",
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/lite:model.h",
|
||||||
|
"//tensorflow/lite/kernels:register.h",
|
||||||
|
"//tensorflow/lite/tools/evaluation:utils.h",
|
||||||
|
],
|
||||||
|
srcs = [
|
||||||
|
"//tensorflow/lite:libtensorflowlite.so",
|
||||||
|
],
|
||||||
|
includes = ["tensorflow"],
|
||||||
|
deps = ["//tensorflow/lite:libtensorflowlite.so"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "coqui_stt_bundle",
|
name = "coqui_stt_bundle",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -142,17 +154,10 @@ cc_library(
|
||||||
"modelstate.h",
|
"modelstate.h",
|
||||||
"workspace_status.cc",
|
"workspace_status.cc",
|
||||||
"workspace_status.h",
|
"workspace_status.h",
|
||||||
] + select({
|
|
||||||
"//native_client:tflite": [
|
|
||||||
"tflitemodelstate.h",
|
"tflitemodelstate.h",
|
||||||
"tflitemodelstate.cc",
|
"tflitemodelstate.cc",
|
||||||
],
|
] + DECODER_SOURCES,
|
||||||
"//conditions:default": [
|
copts = select({
|
||||||
"tfmodelstate.h",
|
|
||||||
"tfmodelstate.cc",
|
|
||||||
],
|
|
||||||
}) + DECODER_SOURCES,
|
|
||||||
copts = tf_copts(allow_exceptions=True) + select({
|
|
||||||
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
|
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
|
||||||
"//tensorflow:windows": ["/w"],
|
"//tensorflow:windows": ["/w"],
|
||||||
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
|
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
|
||||||
|
@ -161,9 +166,6 @@ cc_library(
|
||||||
"-Wno-sign-compare",
|
"-Wno-sign-compare",
|
||||||
"-fvisibility=hidden",
|
"-fvisibility=hidden",
|
||||||
],
|
],
|
||||||
}) + select({
|
|
||||||
"//native_client:tflite": ["-DUSE_TFLITE"],
|
|
||||||
"//conditions:default": ["-UUSE_TFLITE"],
|
|
||||||
}),
|
}),
|
||||||
linkopts = lrt_if_needed() + select({
|
linkopts = lrt_if_needed() + select({
|
||||||
"//tensorflow:macos": [],
|
"//tensorflow:macos": [],
|
||||||
|
@ -174,64 +176,32 @@ cc_library(
|
||||||
# Bazel is has too strong opinions about static linking, so it's
|
# Bazel is has too strong opinions about static linking, so it's
|
||||||
# near impossible to get it to link a DLL against another DLL on Windows.
|
# near impossible to get it to link a DLL against another DLL on Windows.
|
||||||
# We simply force the linker option manually here as a hacky fix.
|
# We simply force the linker option manually here as a hacky fix.
|
||||||
"//tensorflow:windows": ["bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib"],
|
"//tensorflow:windows": [
|
||||||
|
"bazel-out/x64_windows-opt/bin/native_client/libkenlm.so.if.lib",
|
||||||
|
"bazel-out/x64_windows-opt/bin/tensorflow/lite/libtensorflowlite.so.if.lib",
|
||||||
|
],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}) + tflite_linkopts() + DECODER_LINKOPTS,
|
}) + DECODER_LINKOPTS,
|
||||||
includes = DECODER_INCLUDES,
|
includes = DECODER_INCLUDES,
|
||||||
deps = select({
|
deps = [":kenlm", ":tflite"],
|
||||||
"//native_client:tflite": [
|
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
"//tensorflow/lite/tools/evaluation:utils",
|
|
||||||
],
|
|
||||||
"//conditions:default": [
|
|
||||||
"//tensorflow/core:core_cpu",
|
|
||||||
"//tensorflow/core:direct_session",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
#"//tensorflow/core:all_kernels",
|
|
||||||
### => Trying to be more fine-grained
|
|
||||||
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
|
|
||||||
### CPU only build, libstt.so file size reduced by ~50%
|
|
||||||
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
|
|
||||||
"//tensorflow/core/kernels:bias_op", # BiasAdd
|
|
||||||
"//tensorflow/core/kernels:cast_op", # Cast
|
|
||||||
"//tensorflow/core/kernels:concat_op", # ConcatV2
|
|
||||||
"//tensorflow/core/kernels:constant_op", # Const, Placeholder
|
|
||||||
"//tensorflow/core/kernels:shape_ops", # ExpandDims, Shape
|
|
||||||
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
|
|
||||||
"//tensorflow/core/kernels:identity_op", # Identity
|
|
||||||
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
|
|
||||||
"//tensorflow/core/kernels:deepspeech_cwise_ops", # Less, Minimum, Mul
|
|
||||||
"//tensorflow/core/kernels:matmul_op", # MatMul
|
|
||||||
"//tensorflow/core/kernels:reduction_ops", # Max
|
|
||||||
"//tensorflow/core/kernels:mfcc_op", # Mfcc
|
|
||||||
"//tensorflow/core/kernels:no_op", # NoOp
|
|
||||||
"//tensorflow/core/kernels:pack_op", # Pack
|
|
||||||
"//tensorflow/core/kernels:sequence_ops", # Range
|
|
||||||
"//tensorflow/core/kernels:relu_op", # Relu
|
|
||||||
"//tensorflow/core/kernels:reshape_op", # Reshape
|
|
||||||
"//tensorflow/core/kernels:softmax_op", # Softmax
|
|
||||||
"//tensorflow/core/kernels:tile_ops", # Tile
|
|
||||||
"//tensorflow/core/kernels:transpose_op", # Transpose
|
|
||||||
"//tensorflow/core/kernels:rnn_ops", # BlockLSTM
|
|
||||||
# And we also need the op libs for these ops used in the model:
|
|
||||||
"//tensorflow/core:audio_ops_op_lib", # AudioSpectrogram, Mfcc
|
|
||||||
"//tensorflow/core:rnn_ops_op_lib", # BlockLSTM
|
|
||||||
"//tensorflow/core:math_ops_op_lib", # Cast, Less, Max, MatMul, Minimum, Range
|
|
||||||
"//tensorflow/core:array_ops_op_lib", # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
|
|
||||||
"//tensorflow/core:no_op_op_lib", # NoOp
|
|
||||||
"//tensorflow/core:nn_ops_op_lib", # Relu, Softmax, BiasAdd
|
|
||||||
# And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
|
|
||||||
"//tensorflow/core:dataset_ops_op_lib", # UnwrapDatasetVariant, WrapDatasetVariant
|
|
||||||
"//tensorflow/core:sendrecv_ops_op_lib", # _HostRecv, _HostSend, _Recv, _Send
|
|
||||||
],
|
|
||||||
}) + if_cuda([
|
|
||||||
"//tensorflow/core:core",
|
|
||||||
]) + [":kenlm"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_shared_object(
|
cc_binary(
|
||||||
name = "libstt.so",
|
name = "libstt.so",
|
||||||
deps = [":coqui_stt_bundle"],
|
deps = [":coqui_stt_bundle"],
|
||||||
|
linkshared = 1,
|
||||||
|
linkopts = select({
|
||||||
|
"//tensorflow:ios": [
|
||||||
|
"-Wl,-install_name,@rpath/libstt.so",
|
||||||
|
],
|
||||||
|
"//tensorflow:macos": [
|
||||||
|
"-Wl,-install_name,@rpath/libstt.so",
|
||||||
|
],
|
||||||
|
"//tensorflow:windows": [],
|
||||||
|
"//conditions:default": [
|
||||||
|
"-Wl,-soname,libstt.so",
|
||||||
|
],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
ios_static_framework(
|
ios_static_framework(
|
||||||
|
|
|
@ -20,8 +20,8 @@ endif
|
||||||
|
|
||||||
STT_BIN := stt$(PLATFORM_EXE_SUFFIX)
|
STT_BIN := stt$(PLATFORM_EXE_SUFFIX)
|
||||||
CFLAGS_STT := -std=c++11 -o $(STT_BIN)
|
CFLAGS_STT := -std=c++11 -o $(STT_BIN)
|
||||||
LINK_STT := -lstt -lkenlm
|
LINK_STT := -lstt -lkenlm -ltensorflowlite
|
||||||
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client
|
LINK_PATH_STT := -L${TFDIR}/bazel-bin/native_client -L${TFDIR}/bazel-bin/tensorflow/lite
|
||||||
|
|
||||||
ifeq ($(TARGET),host)
|
ifeq ($(TARGET),host)
|
||||||
TOOLCHAIN :=
|
TOOLCHAIN :=
|
||||||
|
@ -61,7 +61,7 @@ TOOL_CC := cl.exe
|
||||||
TOOL_CXX := cl.exe
|
TOOL_CXX := cl.exe
|
||||||
TOOL_LD := link.exe
|
TOOL_LD := link.exe
|
||||||
TOOL_LIBEXE := lib.exe
|
TOOL_LIBEXE := lib.exe
|
||||||
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib")
|
LINK_STT := $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libstt.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/native_client/libkenlm.so.if.lib") $(shell cygpath "$(TFDIR)/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib")
|
||||||
LINK_PATH_STT :=
|
LINK_PATH_STT :=
|
||||||
CFLAGS_STT := -nologo -Fe$(STT_BIN)
|
CFLAGS_STT := -nologo -Fe$(STT_BIN)
|
||||||
SOX_CFLAGS :=
|
SOX_CFLAGS :=
|
||||||
|
@ -185,7 +185,7 @@ define copy_missing_libs
|
||||||
new_missing="$$( (for f in $$(otool -L $$lib 2>/dev/null | tail -n +2 | awk '{ print $$1 }' | grep -v '$$lib'); do ls -hal $$f; done;) 2>&1 | grep 'No such' | cut -d':' -f2 | xargs basename -a)"; \
|
new_missing="$$( (for f in $$(otool -L $$lib 2>/dev/null | tail -n +2 | awk '{ print $$1 }' | grep -v '$$lib'); do ls -hal $$f; done;) 2>&1 | grep 'No such' | cut -d':' -f2 | xargs basename -a)"; \
|
||||||
missing_libs="$$missing_libs $$new_missing"; \
|
missing_libs="$$missing_libs $$new_missing"; \
|
||||||
elif [ "$(OS)" = "${CI_MSYS_VERSION}" ]; then \
|
elif [ "$(OS)" = "${CI_MSYS_VERSION}" ]; then \
|
||||||
missing_libs="libstt.so libkenlm.so"; \
|
missing_libs="libstt.so libkenlm.so libtensorflowlite.so"; \
|
||||||
else \
|
else \
|
||||||
missing_libs="$$missing_libs $$($(LDD) $$lib | grep 'not found' | awk '{ print $$1 }')"; \
|
missing_libs="$$missing_libs $$($(LDD) $$lib | grep 'not found' | awk '{ print $$1 }')"; \
|
||||||
fi; \
|
fi; \
|
||||||
|
|
|
@ -27,12 +27,14 @@
|
||||||
"libraries": [
|
"libraries": [
|
||||||
"../../../tensorflow/bazel-bin/native_client/libstt.so.if.lib",
|
"../../../tensorflow/bazel-bin/native_client/libstt.so.if.lib",
|
||||||
"../../../tensorflow/bazel-bin/native_client/libkenlm.so.if.lib",
|
"../../../tensorflow/bazel-bin/native_client/libkenlm.so.if.lib",
|
||||||
|
"../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so.if.lib",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"libraries": [
|
"libraries": [
|
||||||
"../../../tensorflow/bazel-bin/native_client/libstt.so",
|
"../../../tensorflow/bazel-bin/native_client/libstt.so",
|
||||||
"../../../tensorflow/bazel-bin/native_client/libkenlm.so",
|
"../../../tensorflow/bazel-bin/native_client/libkenlm.so",
|
||||||
|
"../../../tensorflow/bazel-bin/tensorflow/lite/libtensorflowlite.so",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|
|
@ -14,13 +14,7 @@
|
||||||
#include "modelstate.h"
|
#include "modelstate.h"
|
||||||
|
|
||||||
#include "workspace_status.h"
|
#include "workspace_status.h"
|
||||||
|
|
||||||
#ifndef USE_TFLITE
|
|
||||||
#include "tfmodelstate.h"
|
|
||||||
#else
|
|
||||||
#include "tflitemodelstate.h"
|
#include "tflitemodelstate.h"
|
||||||
#endif // USE_TFLITE
|
|
||||||
|
|
||||||
#include "ctcdecode/ctc_beam_search_decoder.h"
|
#include "ctcdecode/ctc_beam_search_decoder.h"
|
||||||
|
|
||||||
#ifdef __ANDROID__
|
#ifdef __ANDROID__
|
||||||
|
@ -282,13 +276,7 @@ STT_CreateModel(const char* aModelPath,
|
||||||
return STT_ERR_NO_MODEL;
|
return STT_ERR_NO_MODEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<ModelState> model(
|
std::unique_ptr<ModelState> model(new TFLiteModelState());
|
||||||
#ifndef USE_TFLITE
|
|
||||||
new TFModelState()
|
|
||||||
#else
|
|
||||||
new TFLiteModelState()
|
|
||||||
#endif
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!model) {
|
if (!model) {
|
||||||
std::cerr << "Could not allocate model state." << std::endl;
|
std::cerr << "Could not allocate model state." << std::endl;
|
||||||
|
|
|
@ -1,263 +0,0 @@
|
||||||
#include "tfmodelstate.h"
|
|
||||||
|
|
||||||
#include "workspace_status.h"
|
|
||||||
|
|
||||||
using namespace tensorflow;
|
|
||||||
using std::vector;
|
|
||||||
|
|
||||||
TFModelState::TFModelState()
|
|
||||||
: ModelState()
|
|
||||||
, mmap_env_(nullptr)
|
|
||||||
, session_(nullptr)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
TFModelState::~TFModelState()
|
|
||||||
{
|
|
||||||
if (session_) {
|
|
||||||
Status status = session_->Close();
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int
|
|
||||||
TFModelState::init(const char* model_path)
|
|
||||||
{
|
|
||||||
int err = ModelState::init(model_path);
|
|
||||||
if (err != STT_ERR_OK) {
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status status;
|
|
||||||
SessionOptions options;
|
|
||||||
|
|
||||||
mmap_env_.reset(new MemmappedEnv(Env::Default()));
|
|
||||||
|
|
||||||
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
|
|
||||||
if (!is_mmap) {
|
|
||||||
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
|
|
||||||
} else {
|
|
||||||
status = mmap_env_->InitializeFromFile(model_path);
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << status << std::endl;
|
|
||||||
return STT_ERR_FAIL_INIT_MMAP;
|
|
||||||
}
|
|
||||||
|
|
||||||
options.config.mutable_graph_options()
|
|
||||||
->mutable_optimizer_options()
|
|
||||||
->set_opt_level(::OptimizerOptions::L0);
|
|
||||||
options.env = mmap_env_.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
Session* session;
|
|
||||||
status = NewSession(options, &session);
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << status << std::endl;
|
|
||||||
return STT_ERR_FAIL_INIT_SESS;
|
|
||||||
}
|
|
||||||
session_.reset(session);
|
|
||||||
|
|
||||||
if (is_mmap) {
|
|
||||||
status = ReadBinaryProto(mmap_env_.get(),
|
|
||||||
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
|
|
||||||
&graph_def_);
|
|
||||||
} else {
|
|
||||||
status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
|
|
||||||
}
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << status << std::endl;
|
|
||||||
return STT_ERR_FAIL_READ_PROTOBUF;
|
|
||||||
}
|
|
||||||
|
|
||||||
status = session_->Create(graph_def_);
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << status << std::endl;
|
|
||||||
return STT_ERR_FAIL_CREATE_SESS;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<tensorflow::Tensor> version_output;
|
|
||||||
status = session_->Run({}, {
|
|
||||||
"metadata_version"
|
|
||||||
}, {}, &version_output);
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << "Unable to fetch graph version: " << status << std::endl;
|
|
||||||
return STT_ERR_MODEL_INCOMPATIBLE;
|
|
||||||
}
|
|
||||||
|
|
||||||
int graph_version = version_output[0].scalar<int>()();
|
|
||||||
if (graph_version < ds_graph_version()) {
|
|
||||||
std::cerr << "Specified model file version (" << graph_version << ") is "
|
|
||||||
<< "incompatible with minimum version supported by this client ("
|
|
||||||
<< ds_graph_version() << "). See "
|
|
||||||
<< "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
|
|
||||||
<< "for more information" << std::endl;
|
|
||||||
return STT_ERR_MODEL_INCOMPATIBLE;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<tensorflow::Tensor> metadata_outputs;
|
|
||||||
status = session_->Run({}, {
|
|
||||||
"metadata_sample_rate",
|
|
||||||
"metadata_feature_win_len",
|
|
||||||
"metadata_feature_win_step",
|
|
||||||
"metadata_beam_width",
|
|
||||||
"metadata_alphabet",
|
|
||||||
}, {}, &metadata_outputs);
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
|
||||||
return STT_ERR_MODEL_INCOMPATIBLE;
|
|
||||||
}
|
|
||||||
|
|
||||||
sample_rate_ = metadata_outputs[0].scalar<int>()();
|
|
||||||
int win_len_ms = metadata_outputs[1].scalar<int>()();
|
|
||||||
int win_step_ms = metadata_outputs[2].scalar<int>()();
|
|
||||||
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
|
||||||
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
|
||||||
int beam_width = metadata_outputs[3].scalar<int>()();
|
|
||||||
beam_width_ = (unsigned int)(beam_width);
|
|
||||||
|
|
||||||
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
|
|
||||||
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
|
|
||||||
if (err != 0) {
|
|
||||||
return STT_ERR_INVALID_ALPHABET;
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(sample_rate_ > 0);
|
|
||||||
assert(audio_win_len_ > 0);
|
|
||||||
assert(audio_win_step_ > 0);
|
|
||||||
assert(beam_width_ > 0);
|
|
||||||
assert(alphabet_.GetSize() > 0);
|
|
||||||
|
|
||||||
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
|
||||||
NodeDef node = graph_def_.node(i);
|
|
||||||
if (node.name() == "input_node") {
|
|
||||||
const auto& shape = node.attr().at("shape").shape();
|
|
||||||
n_steps_ = shape.dim(1).size();
|
|
||||||
n_context_ = (shape.dim(2).size()-1)/2;
|
|
||||||
n_features_ = shape.dim(3).size();
|
|
||||||
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
|
|
||||||
} else if (node.name() == "previous_state_c") {
|
|
||||||
const auto& shape = node.attr().at("shape").shape();
|
|
||||||
state_size_ = shape.dim(1).size();
|
|
||||||
} else if (node.name() == "logits_shape") {
|
|
||||||
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
|
|
||||||
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
int final_dim_size = logits_shape.vec<int>()(2) - 1;
|
|
||||||
if (final_dim_size != alphabet_.GetSize()) {
|
|
||||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
|
||||||
<< "has size " << alphabet_.GetSize()
|
|
||||||
<< ", but model has " << final_dim_size
|
|
||||||
<< " classes in its output. Make sure you're passing an alphabet "
|
|
||||||
<< "file with the same size as the one used for training."
|
|
||||||
<< std::endl;
|
|
||||||
return STT_ERR_INVALID_ALPHABET;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_context_ == -1 || n_features_ == -1) {
|
|
||||||
std::cerr << "Error: Could not infer input shape from model file. "
|
|
||||||
<< "Make sure input_node is a 4D tensor with shape "
|
|
||||||
<< "[batch_size=1, time, window_size, n_features]."
|
|
||||||
<< std::endl;
|
|
||||||
return STT_ERR_INVALID_SHAPE;
|
|
||||||
}
|
|
||||||
|
|
||||||
return STT_ERR_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
Tensor
|
|
||||||
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
|
|
||||||
{
|
|
||||||
Tensor ret(DT_FLOAT, shape);
|
|
||||||
auto ret_mapped = ret.flat<float>();
|
|
||||||
int i;
|
|
||||||
for (i = 0; i < vec.size(); ++i) {
|
|
||||||
ret_mapped(i) = vec[i];
|
|
||||||
}
|
|
||||||
for (; i < shape.num_elements(); ++i) {
|
|
||||||
ret_mapped(i) = 0.f;
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
void
|
|
||||||
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
|
|
||||||
{
|
|
||||||
auto tensor_mapped = tensor.flat<float>();
|
|
||||||
if (num_elements == -1) {
|
|
||||||
num_elements = tensor.shape().num_elements();
|
|
||||||
}
|
|
||||||
for (int i = 0; i < num_elements; ++i) {
|
|
||||||
vec.push_back(tensor_mapped(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void
|
|
||||||
TFModelState::infer(const std::vector<float>& mfcc,
|
|
||||||
unsigned int n_frames,
|
|
||||||
const std::vector<float>& previous_state_c,
|
|
||||||
const std::vector<float>& previous_state_h,
|
|
||||||
vector<float>& logits_output,
|
|
||||||
vector<float>& state_c_output,
|
|
||||||
vector<float>& state_h_output)
|
|
||||||
{
|
|
||||||
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
|
|
||||||
|
|
||||||
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
|
||||||
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
|
||||||
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
|
||||||
|
|
||||||
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
|
||||||
input_lengths.scalar<int>()() = n_frames;
|
|
||||||
|
|
||||||
vector<Tensor> outputs;
|
|
||||||
Status status = session_->Run(
|
|
||||||
{
|
|
||||||
{"input_node", input},
|
|
||||||
{"input_lengths", input_lengths},
|
|
||||||
{"previous_state_c", previous_state_c_t},
|
|
||||||
{"previous_state_h", previous_state_h_t}
|
|
||||||
},
|
|
||||||
{"logits", "new_state_c", "new_state_h"},
|
|
||||||
{},
|
|
||||||
&outputs);
|
|
||||||
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << "Error running session: " << status << "\n";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
|
|
||||||
|
|
||||||
state_c_output.clear();
|
|
||||||
state_c_output.reserve(state_size_);
|
|
||||||
copy_tensor_to_vector(outputs[1], state_c_output);
|
|
||||||
|
|
||||||
state_h_output.clear();
|
|
||||||
state_h_output.reserve(state_size_);
|
|
||||||
copy_tensor_to_vector(outputs[2], state_h_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
void
|
|
||||||
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
|
||||||
{
|
|
||||||
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
|
|
||||||
|
|
||||||
vector<Tensor> outputs;
|
|
||||||
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
|
|
||||||
|
|
||||||
if (!status.ok()) {
|
|
||||||
std::cerr << "Error running session: " << status << "\n";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The feature computation graph is hardcoded to one audio length for now
|
|
||||||
const int n_windows = 1;
|
|
||||||
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
|
|
||||||
copy_tensor_to_vector(outputs[0], mfcc_output);
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
#ifndef TFMODELSTATE_H
|
|
||||||
#define TFMODELSTATE_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/core/public/session.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
|
||||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
|
||||||
|
|
||||||
#include "modelstate.h"
|
|
||||||
|
|
||||||
struct TFModelState : public ModelState
|
|
||||||
{
|
|
||||||
std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
|
|
||||||
std::unique_ptr<tensorflow::Session> session_;
|
|
||||||
tensorflow::GraphDef graph_def_;
|
|
||||||
|
|
||||||
TFModelState();
|
|
||||||
virtual ~TFModelState();
|
|
||||||
|
|
||||||
virtual int init(const char* model_path) override;
|
|
||||||
|
|
||||||
virtual void infer(const std::vector<float>& mfcc,
|
|
||||||
unsigned int n_frames,
|
|
||||||
const std::vector<float>& previous_state_c,
|
|
||||||
const std::vector<float>& previous_state_h,
|
|
||||||
std::vector<float>& logits_output,
|
|
||||||
std::vector<float>& state_c_output,
|
|
||||||
std::vector<float>& state_h_output) override;
|
|
||||||
|
|
||||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
|
||||||
std::vector<float>& mfcc_output) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // TFMODELSTATE_H
|
|
Loading…
Reference in New Issue