diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst index 17aa8ac1..f7af448e 100644 --- a/doc/TRAINING.rst +++ b/doc/TRAINING.rst @@ -202,7 +202,7 @@ Refer to the :ref:`usage instructions ` for information on running a Exporting a model for TFLite ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the ``--export_tflite`` flags. If you already have a trained model, you can re-export it for TFLite by running ``DeepSpeech.py`` again and specifying the same ``checkpoint_dir`` that you used for training, as well as passing ``--export_tflite --export_dir /model/export/destination``. +If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the ``--export_tflite`` flags. If you already have a trained model, you can re-export it for TFLite by running ``DeepSpeech.py`` again and specifying the same ``checkpoint_dir`` that you used for training, as well as passing ``--export_tflite --export_dir /model/export/destination``. If you changed the alphabet you also need to add the ``--alphabet_config_path my-new-language-alphabet.txt`` flag. Making a mmap-able model for inference ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc index 1f0a8dbe..34b97a3b 100644 --- a/native_client/alphabet.cc +++ b/native_client/alphabet.cc @@ -57,6 +57,9 @@ Alphabet::init(const char *config_file) if (line == " ") { space_label_ = label; } + if (line.length() == 0) { + continue; + } label_to_str_[label] = line; str_to_label_[line] = label; ++label; @@ -187,3 +190,14 @@ Alphabet::Encode(const std::string& input) const } return result; } + +std::vector +UTF8Alphabet::Encode(const std::string& input) const +{ + std::vector result; + for (auto byte_char : input) { + std::string byte_str(1, byte_char); + result.push_back(EncodeSingle(byte_str)); + } + return result; +} diff --git a/native_client/alphabet.h b/native_client/alphabet.h index 45fc444e..1303cf89 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -52,7 +52,7 @@ public: // Encode a sequence of character/output classes into a sequence of labels. // Characters are assumed to always take a single Unicode codepoint. - std::vector Encode(const std::string& input) const; + virtual std::vector Encode(const std::string& input) const; protected: size_t size_; @@ -77,7 +77,8 @@ public: int init(const char*) override { return 0; } + + std::vector Encode(const std::string& input) const override; }; - #endif //ALPHABET_H diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index ee5645d4..e66633b6 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -3,7 +3,9 @@ from __future__ import absolute_import, division, print_function from . import swigwrapper # pylint: disable=import-self from .swigwrapper import UTF8Alphabet -__version__ = swigwrapper.__version__ +# This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle +# string encoding explicitly, here and throughout this file. +__version__ = swigwrapper.__version__.decode('utf-8') # Hack: import error codes by matching on their names, as SWIG unfortunately # does not support binding enums to Python in a scoped manner yet. @@ -30,7 +32,7 @@ class Scorer(swigwrapper.Scorer): assert beta is not None, 'beta parameter is required' assert scorer_path, 'scorer_path parameter is required' - err = self.init(scorer_path, alphabet) + err = self.init(scorer_path.encode('utf-8'), alphabet) if err != 0: raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err)) @@ -41,15 +43,27 @@ class Alphabet(swigwrapper.Alphabet): """Convenience wrapper for Alphabet which calls init in the constructor""" def __init__(self, config_path): super(Alphabet, self).__init__() - err = self.init(config_path) + err = self.init(config_path.encode('utf-8')) if err != 0: raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err)) + def EncodeSingle(self, input): + return super(Alphabet, self).EncodeSingle(input.encode('utf-8')) + def Encode(self, input): - """Convert SWIG's UnsignedIntVec to a Python list""" - res = super(Alphabet, self).Encode(input) + # Convert SWIG's UnsignedIntVec to a Python list + res = super(Alphabet, self).Encode(input.encode('utf-8')) return [el for el in res] + def DecodeSingle(self, input): + res = super(Alphabet, self).DecodeSingle(input) + return res.decode('utf-8') + + def Decode(self, input): + res = super(Alphabet, self).Decode(input) + return res.decode('utf-8') + + def ctc_beam_search_decoder(probs_seq, alphabet, diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 5fd45931..8afd416e 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -182,8 +182,9 @@ DecoderState::decode(size_t num_results) const // score the last word of each prefix that doesn't end with space if (ext_scorer_) { for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) { - auto prefix = prefixes_copy[i]; - if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) { + PathTrie* prefix = prefixes_copy[i]; + PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent; + if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) { float score = 0.0; std::vector ngram = ext_scorer_->make_ngram(prefix); bool bos = ngram.size() < ext_scorer_->get_max_order(); diff --git a/native_client/ctcdecode/swigwrapper.i b/native_client/ctcdecode/swigwrapper.i index ffe23c3a..dbe67c68 100644 --- a/native_client/ctcdecode/swigwrapper.i +++ b/native_client/ctcdecode/swigwrapper.i @@ -3,6 +3,7 @@ %{ #include "ctc_beam_search_decoder.h" #define SWIG_FILE_WITH_INIT +#define SWIG_PYTHON_STRICT_BYTE_CHAR #include "workspace_status.h" %}