Enforce CTC decoder version check

Fix #2710
This commit is contained in:
Alexandre Lissy 2020-01-31 23:43:19 +01:00
parent 6af68efa00
commit ff401732a3
11 changed files with 55 additions and 6 deletions

View File

@ -27,6 +27,7 @@ from util.config import Config, initialize_globals
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
# Graph Creation

View File

@ -22,6 +22,7 @@ from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
from util.flags import create_flags, FLAGS
from util.logging import log_error, log_progress, create_progressbar
from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet):

View File

@ -18,9 +18,12 @@ set -ex
tf_git_rev=$(git describe --long --tags)
echo "STABLE_TF_GIT_VERSION ${tf_git_rev}"
pushd native_client
# use this trick to be able to use the script from anywhere
pushd $(dirname "$0")
ds_git_rev=$(git describe --long --tags)
echo "STABLE_DS_GIT_VERSION ${ds_git_rev}"
ds_version=$(cat ../VERSION)
echo "STABLE_DS_VERSION ${ds_version}"
ds_graph_version=$(cat ../GRAPH_VERSION)
echo "STABLE_DS_GRAPH_VERSION ${ds_graph_version}"
popd

View File

@ -1,4 +1,4 @@
.PHONY: bindings clean
.PHONY: bindings clean workspace_status.cc
include ../definitions.mk
@ -23,15 +23,22 @@ clean-keep-common:
clean: clean-keep-common
rm -f common.a
rm workspace_status.cc
rm -fr bazel-out/
bindings: clean-keep-common
workspace_status.cc:
mkdir -p bazel-out/ && \
../bazel_workspace_status_cmd.sh > bazel-out/stable-status.txt && \
../gen_workspace_status.sh > $@
bindings: clean-keep-common workspace_status.cc
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
find temp_build -type f -name "*.o" -delete
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
rm -rf temp_build
bindings-debug: clean-keep-common
bindings-debug: clean-keep-common workspace_status.cc
pip install --quiet $(PYTHON_PACKAGES) wheel==0.33.6 setuptools==39.1.0
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS) -DDEBUG" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --debug --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
$(GENERATE_DEBUG_SYMS)

View File

@ -1,7 +1,8 @@
from __future__ import absolute_import, division, print_function
from . import swigwrapper
from . import swigwrapper # pylint: disable=import-self
__version__ = swigwrapper.__version__
class Scorer(swigwrapper.Scorer):
"""Wrapper for Scorer.

View File

@ -60,7 +60,8 @@ decoder_module = Extension(
'ctc_beam_search_decoder.cpp',
'scorer.cpp',
'path_trie.cpp',
'decoder_utils.cpp'],
'decoder_utils.cpp',
'workspace_status.cc'],
swig_opts=['-c++', '-extranative'],
language='c++',
include_dirs=INCLUDES + [numpy_include],

View File

@ -4,6 +4,7 @@
#include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
#include "workspace_status.h"
%}
%include "pyabc.i"
@ -27,6 +28,9 @@ import_array();
%include "scorer.h"
%include "ctc_beam_search_decoder.h"
%constant const char* __version__ = ds_version();
%constant const char* __git_version__ = ds_git_version();
%template(IntVector) std::vector<int>;
%template(OutputVector) std::vector<Output>;
%template(OutputVectorVector) std::vector<std::vector<Output>>;

View File

@ -3,6 +3,7 @@
set -x
tf_git_version=$(grep "STABLE_TF_GIT_VERSION" "bazel-out/stable-status.txt" | cut -d' ' -f2)
ds_version=$(grep "STABLE_DS_VERSION" "bazel-out/stable-status.txt" | cut -d' ' -f2)
ds_git_version=$(grep "STABLE_DS_GIT_VERSION" "bazel-out/stable-status.txt" | cut -d' ' -f2)
ds_graph_version=$(grep "STABLE_DS_GRAPH_VERSION" "bazel-out/stable-status.txt" | cut -d' ' -f2)
@ -10,6 +11,9 @@ cat <<EOF
const char *tf_local_git_version() {
return "${tf_git_version}";
}
const char *ds_version() {
return "${ds_version}";
}
const char *ds_git_version() {
return "${ds_git_version}";
}

View File

@ -2,6 +2,7 @@
#define WORKSPACE_STATUS_H
const char *tf_local_git_version();
const char *ds_version();
const char *ds_git_version();
const int ds_graph_version();

View File

@ -7,6 +7,7 @@ six
pyxdg
attrdict
absl-py
semver
# Requirements for building native_client files
setuptools

View File

@ -7,3 +7,28 @@ def secs_to_hours(secs):
hours, remainder = divmod(secs, 3600)
minutes, seconds = divmod(remainder, 60)
return '%d:%02d:%02d' % (hours, minutes, seconds)
# pylint: disable=import-outside-toplevel
def check_ctcdecoder_version():
import sys
import os
import semver
ds_version_s = open(os.path.join(os.path.dirname(__file__), '../VERSION')).read().strip()
try:
from ds_ctcdecoder import __version__ as decoder_version
except ImportError as e:
if e.msg.find('__version__') > 0:
print("DeepSpeech version ({ds_version}) requires CTC decoder to expose __version__. Please upgrade the ds_ctcdecoder package to version {ds_version}".format(ds_version=ds_version_s))
sys.exit(1)
raise e
decoder_version_s = decoder_version.decode()
rv = semver.compare(ds_version_s, decoder_version_s)
if rv != 0:
print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
sys.exit(1)
return rv