Merge branch 'master' into r0.8

This commit is contained in:
Reuben Morais 2020-07-04 13:21:38 +02:00
commit 94782b94c5
6 changed files with 41 additions and 10 deletions

View File

@ -202,7 +202,7 @@ Refer to the :ref:`usage instructions <usage-docs>` for information on running a
Exporting a model for TFLite
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the ``--export_tflite`` flags. If you already have a trained model, you can re-export it for TFLite by running ``DeepSpeech.py`` again and specifying the same ``checkpoint_dir`` that you used for training, as well as passing ``--export_tflite --export_dir /model/export/destination``.
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the ``--export_tflite`` flags. If you already have a trained model, you can re-export it for TFLite by running ``DeepSpeech.py`` again and specifying the same ``checkpoint_dir`` that you used for training, as well as passing ``--export_tflite --export_dir /model/export/destination``. If you changed the alphabet you also need to add the ``--alphabet_config_path my-new-language-alphabet.txt`` flag.
Making a mmap-able model for inference
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -57,6 +57,9 @@ Alphabet::init(const char *config_file)
if (line == " ") {
space_label_ = label;
}
if (line.length() == 0) {
continue;
}
label_to_str_[label] = line;
str_to_label_[line] = label;
++label;
@ -187,3 +190,14 @@ Alphabet::Encode(const std::string& input) const
}
return result;
}
std::vector<unsigned int>
UTF8Alphabet::Encode(const std::string& input) const
{
std::vector<unsigned int> result;
for (auto byte_char : input) {
std::string byte_str(1, byte_char);
result.push_back(EncodeSingle(byte_str));
}
return result;
}

View File

@ -52,7 +52,7 @@ public:
// Encode a sequence of character/output classes into a sequence of labels.
// Characters are assumed to always take a single Unicode codepoint.
std::vector<unsigned int> Encode(const std::string& input) const;
virtual std::vector<unsigned int> Encode(const std::string& input) const;
protected:
size_t size_;
@ -77,7 +77,8 @@ public:
int init(const char*) override {
return 0;
}
std::vector<unsigned int> Encode(const std::string& input) const override;
};
#endif //ALPHABET_H

View File

@ -3,7 +3,9 @@ from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import UTF8Alphabet
__version__ = swigwrapper.__version__
# This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle
# string encoding explicitly, here and throughout this file.
__version__ = swigwrapper.__version__.decode('utf-8')
# Hack: import error codes by matching on their names, as SWIG unfortunately
# does not support binding enums to Python in a scoped manner yet.
@ -30,7 +32,7 @@ class Scorer(swigwrapper.Scorer):
assert beta is not None, 'beta parameter is required'
assert scorer_path, 'scorer_path parameter is required'
err = self.init(scorer_path, alphabet)
err = self.init(scorer_path.encode('utf-8'), alphabet)
if err != 0:
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
@ -41,15 +43,27 @@ class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path):
super(Alphabet, self).__init__()
err = self.init(config_path)
err = self.init(config_path.encode('utf-8'))
if err != 0:
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
def EncodeSingle(self, input):
return super(Alphabet, self).EncodeSingle(input.encode('utf-8'))
def Encode(self, input):
"""Convert SWIG's UnsignedIntVec to a Python list"""
res = super(Alphabet, self).Encode(input)
# Convert SWIG's UnsignedIntVec to a Python list
res = super(Alphabet, self).Encode(input.encode('utf-8'))
return [el for el in res]
def DecodeSingle(self, input):
res = super(Alphabet, self).DecodeSingle(input)
return res.decode('utf-8')
def Decode(self, input):
res = super(Alphabet, self).Decode(input)
return res.decode('utf-8')
def ctc_beam_search_decoder(probs_seq,
alphabet,

View File

@ -182,8 +182,9 @@ DecoderState::decode(size_t num_results) const
// score the last word of each prefix that doesn't end with space
if (ext_scorer_) {
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i];
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
PathTrie* prefix = prefixes_copy[i];
PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent;
if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
bool bos = ngram.size() < ext_scorer_->get_max_order();

View File

@ -3,6 +3,7 @@
%{
#include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
#include "workspace_status.h"
%}