Fix binding of UTF8Alphabet class in decoder package

This commit is contained in:
Reuben Morais 2020-10-05 09:46:38 +02:00
parent 421f44cf73
commit 2fd11dd74a

View File

@ -1,7 +1,6 @@
from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import UTF8Alphabet
# This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle
# string encoding explicitly, here and throughout this file.
@ -89,6 +88,56 @@ class Alphabet(swigwrapper.Alphabet):
return res.decode('utf-8')
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self):
super(UTF8Alphabet, self).__init__()
err = self.init(b'')
if err != 0:
raise ValueError('UTF8Alphabet initialization failed with error code 0x{:X}'.format(err))
def CanEncodeSingle(self, input):
'''
Returns true if the single character/output class has a corresponding label
in the alphabet.
'''
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode('utf-8'))
def CanEncode(self, input):
'''
Returns true if the entire string can be encoded into labels in this
alphabet.
'''
return super(UTF8Alphabet, self).CanEncode(input.encode('utf-8'))
def EncodeSingle(self, input):
'''
Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
'''
return super(UTF8Alphabet, self).EncodeSingle(input.encode('utf-8'))
def Encode(self, input):
'''
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test.
'''
# Convert SWIG's UnsignedIntVec to a Python list
res = super(UTF8Alphabet, self).Encode(input.encode('utf-8'))
return [el for el in res]
def DecodeSingle(self, input):
res = super(UTF8Alphabet, self).DecodeSingle(input)
return res.decode('utf-8')
def Decode(self, input):
'''Decode a sequence of labels into a string.'''
res = super(UTF8Alphabet, self).Decode(input)
return res.decode('utf-8')
def ctc_beam_search_decoder(probs_seq,
alphabet,