diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index 94e03b15..80edc51d 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, division, print_function from . import swigwrapper # pylint: disable=import-self -from .swigwrapper import UTF8Alphabet # This module is built with SWIG_PYTHON_STRICT_BYTE_CHAR so we must handle # string encoding explicitly, here and throughout this file. @@ -89,6 +88,56 @@ class Alphabet(swigwrapper.Alphabet): return res.decode('utf-8') +class UTF8Alphabet(swigwrapper.UTF8Alphabet): + """Convenience wrapper for Alphabet which calls init in the constructor""" + def __init__(self): + super(UTF8Alphabet, self).__init__() + err = self.init(b'') + if err != 0: + raise ValueError('UTF8Alphabet initialization failed with error code 0x{:X}'.format(err)) + + def CanEncodeSingle(self, input): + ''' + Returns true if the single character/output class has a corresponding label + in the alphabet. + ''' + return super(UTF8Alphabet, self).CanEncodeSingle(input.encode('utf-8')) + + def CanEncode(self, input): + ''' + Returns true if the entire string can be encoded into labels in this + alphabet. + ''' + return super(UTF8Alphabet, self).CanEncode(input.encode('utf-8')) + + def EncodeSingle(self, input): + ''' + Encode a single character/output class into a label. Character must be in + the alphabet, this method will assert that. Use `CanEncodeSingle` to test. + ''' + return super(UTF8Alphabet, self).EncodeSingle(input.encode('utf-8')) + + def Encode(self, input): + ''' + Encode a sequence of character/output classes into a sequence of labels. + Characters are assumed to always take a single Unicode codepoint. + Characters must be in the alphabet, this method will assert that. Use + `CanEncode` and `CanEncodeSingle` to test. + ''' + # Convert SWIG's UnsignedIntVec to a Python list + res = super(UTF8Alphabet, self).Encode(input.encode('utf-8')) + return [el for el in res] + + def DecodeSingle(self, input): + res = super(UTF8Alphabet, self).DecodeSingle(input) + return res.decode('utf-8') + + def Decode(self, input): + '''Decode a sequence of labels into a string.''' + res = super(UTF8Alphabet, self).Decode(input) + return res.decode('utf-8') + + def ctc_beam_search_decoder(probs_seq, alphabet,