Merge pull request #3125 from mozilla/utf8-regressions
Fix some regressions from Alphabet refactoring (Fixes #3123)
This commit is contained in:
commit
66d1f167fc
@ -57,6 +57,9 @@ Alphabet::init(const char *config_file)
|
|||||||
if (line == " ") {
|
if (line == " ") {
|
||||||
space_label_ = label;
|
space_label_ = label;
|
||||||
}
|
}
|
||||||
|
if (line.length() == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
label_to_str_[label] = line;
|
label_to_str_[label] = line;
|
||||||
str_to_label_[line] = label;
|
str_to_label_[line] = label;
|
||||||
++label;
|
++label;
|
||||||
@ -187,3 +190,14 @@ Alphabet::Encode(const std::string& input) const
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<unsigned int>
|
||||||
|
UTF8Alphabet::Encode(const std::string& input) const
|
||||||
|
{
|
||||||
|
std::vector<unsigned int> result;
|
||||||
|
for (auto byte_char : input) {
|
||||||
|
std::string byte_str(1, byte_char);
|
||||||
|
result.push_back(EncodeSingle(byte_str));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
@ -52,7 +52,7 @@ public:
|
|||||||
|
|
||||||
// Encode a sequence of character/output classes into a sequence of labels.
|
// Encode a sequence of character/output classes into a sequence of labels.
|
||||||
// Characters are assumed to always take a single Unicode codepoint.
|
// Characters are assumed to always take a single Unicode codepoint.
|
||||||
std::vector<unsigned int> Encode(const std::string& input) const;
|
virtual std::vector<unsigned int> Encode(const std::string& input) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
size_t size_;
|
size_t size_;
|
||||||
@ -77,7 +77,8 @@ public:
|
|||||||
int init(const char*) override {
|
int init(const char*) override {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<unsigned int> Encode(const std::string& input) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#endif //ALPHABET_H
|
#endif //ALPHABET_H
|
||||||
|
@ -3,7 +3,9 @@ from __future__ import absolute_import, division, print_function
|
|||||||
from . import swigwrapper # pylint: disable=import-self
|
from . import swigwrapper # pylint: disable=import-self
|
||||||
from .swigwrapper import UTF8Alphabet
|
from .swigwrapper import UTF8Alphabet
|
||||||
|
|
||||||
__version__ = swigwrapper.__version__
|
# This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle
|
||||||
|
# string encoding explicitly, here and throughout this file.
|
||||||
|
__version__ = swigwrapper.__version__.decode('utf-8')
|
||||||
|
|
||||||
# Hack: import error codes by matching on their names, as SWIG unfortunately
|
# Hack: import error codes by matching on their names, as SWIG unfortunately
|
||||||
# does not support binding enums to Python in a scoped manner yet.
|
# does not support binding enums to Python in a scoped manner yet.
|
||||||
@ -30,7 +32,7 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
assert beta is not None, 'beta parameter is required'
|
assert beta is not None, 'beta parameter is required'
|
||||||
assert scorer_path, 'scorer_path parameter is required'
|
assert scorer_path, 'scorer_path parameter is required'
|
||||||
|
|
||||||
err = self.init(scorer_path, alphabet)
|
err = self.init(scorer_path.encode('utf-8'), alphabet)
|
||||||
if err != 0:
|
if err != 0:
|
||||||
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
|
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
|
||||||
|
|
||||||
@ -41,15 +43,27 @@ class Alphabet(swigwrapper.Alphabet):
|
|||||||
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path):
|
||||||
super(Alphabet, self).__init__()
|
super(Alphabet, self).__init__()
|
||||||
err = self.init(config_path)
|
err = self.init(config_path.encode('utf-8'))
|
||||||
if err != 0:
|
if err != 0:
|
||||||
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
|
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
|
||||||
|
|
||||||
|
def EncodeSingle(self, input):
|
||||||
|
return super(Alphabet, self).EncodeSingle(input.encode('utf-8'))
|
||||||
|
|
||||||
def Encode(self, input):
|
def Encode(self, input):
|
||||||
"""Convert SWIG's UnsignedIntVec to a Python list"""
|
# Convert SWIG's UnsignedIntVec to a Python list
|
||||||
res = super(Alphabet, self).Encode(input)
|
res = super(Alphabet, self).Encode(input.encode('utf-8'))
|
||||||
return [el for el in res]
|
return [el for el in res]
|
||||||
|
|
||||||
|
def DecodeSingle(self, input):
|
||||||
|
res = super(Alphabet, self).DecodeSingle(input)
|
||||||
|
return res.decode('utf-8')
|
||||||
|
|
||||||
|
def Decode(self, input):
|
||||||
|
res = super(Alphabet, self).Decode(input)
|
||||||
|
return res.decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(probs_seq,
|
def ctc_beam_search_decoder(probs_seq,
|
||||||
alphabet,
|
alphabet,
|
||||||
|
@ -182,8 +182,9 @@ DecoderState::decode(size_t num_results) const
|
|||||||
// score the last word of each prefix that doesn't end with space
|
// score the last word of each prefix that doesn't end with space
|
||||||
if (ext_scorer_) {
|
if (ext_scorer_) {
|
||||||
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
||||||
auto prefix = prefixes_copy[i];
|
PathTrie* prefix = prefixes_copy[i];
|
||||||
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
|
PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent;
|
||||||
|
if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) {
|
||||||
float score = 0.0;
|
float score = 0.0;
|
||||||
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
|
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
|
||||||
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
%{
|
%{
|
||||||
#include "ctc_beam_search_decoder.h"
|
#include "ctc_beam_search_decoder.h"
|
||||||
#define SWIG_FILE_WITH_INIT
|
#define SWIG_FILE_WITH_INIT
|
||||||
|
#define SWIG_PYTHON_STRICT_BYTE_CHAR
|
||||||
#include "workspace_status.h"
|
#include "workspace_status.h"
|
||||||
%}
|
%}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user