Merge pull request #3125 from mozilla/utf8-regressions

Fix some regressions from Alphabet refactoring (Fixes #3123)
This commit is contained in:
Reuben Morais 2020-07-04 13:21:09 +02:00 committed by GitHub
commit 66d1f167fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 9 deletions

View File

@ -57,6 +57,9 @@ Alphabet::init(const char *config_file)
if (line == " ") { if (line == " ") {
space_label_ = label; space_label_ = label;
} }
if (line.length() == 0) {
continue;
}
label_to_str_[label] = line; label_to_str_[label] = line;
str_to_label_[line] = label; str_to_label_[line] = label;
++label; ++label;
@ -187,3 +190,14 @@ Alphabet::Encode(const std::string& input) const
} }
return result; return result;
} }
std::vector<unsigned int>
UTF8Alphabet::Encode(const std::string& input) const
{
std::vector<unsigned int> result;
for (auto byte_char : input) {
std::string byte_str(1, byte_char);
result.push_back(EncodeSingle(byte_str));
}
return result;
}

View File

@ -52,7 +52,7 @@ public:
// Encode a sequence of character/output classes into a sequence of labels. // Encode a sequence of character/output classes into a sequence of labels.
// Characters are assumed to always take a single Unicode codepoint. // Characters are assumed to always take a single Unicode codepoint.
std::vector<unsigned int> Encode(const std::string& input) const; virtual std::vector<unsigned int> Encode(const std::string& input) const;
protected: protected:
size_t size_; size_t size_;
@ -77,7 +77,8 @@ public:
int init(const char*) override { int init(const char*) override {
return 0; return 0;
} }
std::vector<unsigned int> Encode(const std::string& input) const override;
}; };
#endif //ALPHABET_H #endif //ALPHABET_H

View File

@ -3,7 +3,9 @@ from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import UTF8Alphabet from .swigwrapper import UTF8Alphabet
__version__ = swigwrapper.__version__ # This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle
# string encoding explicitly, here and throughout this file.
__version__ = swigwrapper.__version__.decode('utf-8')
# Hack: import error codes by matching on their names, as SWIG unfortunately # Hack: import error codes by matching on their names, as SWIG unfortunately
# does not support binding enums to Python in a scoped manner yet. # does not support binding enums to Python in a scoped manner yet.
@ -30,7 +32,7 @@ class Scorer(swigwrapper.Scorer):
assert beta is not None, 'beta parameter is required' assert beta is not None, 'beta parameter is required'
assert scorer_path, 'scorer_path parameter is required' assert scorer_path, 'scorer_path parameter is required'
err = self.init(scorer_path, alphabet) err = self.init(scorer_path.encode('utf-8'), alphabet)
if err != 0: if err != 0:
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err)) raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
@ -41,15 +43,27 @@ class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor""" """Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path): def __init__(self, config_path):
super(Alphabet, self).__init__() super(Alphabet, self).__init__()
err = self.init(config_path) err = self.init(config_path.encode('utf-8'))
if err != 0: if err != 0:
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err)) raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
def EncodeSingle(self, input):
return super(Alphabet, self).EncodeSingle(input.encode('utf-8'))
def Encode(self, input): def Encode(self, input):
"""Convert SWIG's UnsignedIntVec to a Python list""" # Convert SWIG's UnsignedIntVec to a Python list
res = super(Alphabet, self).Encode(input) res = super(Alphabet, self).Encode(input.encode('utf-8'))
return [el for el in res] return [el for el in res]
def DecodeSingle(self, input):
res = super(Alphabet, self).DecodeSingle(input)
return res.decode('utf-8')
def Decode(self, input):
res = super(Alphabet, self).Decode(input)
return res.decode('utf-8')
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
alphabet, alphabet,

View File

@ -182,8 +182,9 @@ DecoderState::decode(size_t num_results) const
// score the last word of each prefix that doesn't end with space // score the last word of each prefix that doesn't end with space
if (ext_scorer_) { if (ext_scorer_) {
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) { for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i]; PathTrie* prefix = prefixes_copy[i];
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) { PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent;
if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) {
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix); std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
bool bos = ngram.size() < ext_scorer_->get_max_order(); bool bos = ngram.size() < ext_scorer_->get_max_order();

View File

@ -3,6 +3,7 @@
%{ %{
#include "ctc_beam_search_decoder.h" #include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT #define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
#include "workspace_status.h" #include "workspace_status.h"
%} %}