diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc
index 34b97a3b..9abc65a5 100644
--- a/native_client/alphabet.cc
+++ b/native_client/alphabet.cc
@@ -137,6 +137,27 @@ Alphabet::Deserialize(const char* buffer, const int buffer_size)
   return 0;
 }
 
+// Returns true if `input` (a single character/output class) has an entry in
+// the string-to-label map.
+bool
+Alphabet::CanEncodeSingle(const std::string& input) const
+{
+  return str_to_label_.find(input) != str_to_label_.end();
+}
+
+// Returns true if every piece produced by split_into_codepoints(input) has a
+// label; stops at the first piece that does not.
+bool
+Alphabet::CanEncode(const std::string& input) const
+{
+  for (const auto& cp : split_into_codepoints(input)) {
+    if (!CanEncodeSingle(cp)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 std::string
 Alphabet::DecodeSingle(unsigned int label) const
 {
@@ -191,6 +212,20 @@ Alphabet::Encode(const std::string& input) const
   return result;
 }
 
+// Always returns true; the argument is not inspected.
+bool
+UTF8Alphabet::CanEncodeSingle(const std::string& input) const
+{
+  return true;
+}
+
+// Always returns true; the argument is not inspected.
+bool
+UTF8Alphabet::CanEncode(const std::string& input) const
+{
+  return true;
+}
+
 std::vector
 UTF8Alphabet::Encode(const std::string& input) const
 {
diff --git a/native_client/alphabet.h b/native_client/alphabet.h
index 1303cf89..f402cc0d 100644
--- a/native_client/alphabet.h
+++ b/native_client/alphabet.h
@@ -37,10 +37,19 @@ public:
     return space_label_;
   }
 
+  // Returns true if the single character/output class has a corresponding label
+  // in the alphabet.
+  virtual bool CanEncodeSingle(const std::string& string) const;
+
+  // Returns true if the entire string can be encoded into labels in this
+  // alphabet.
+  virtual bool CanEncode(const std::string& string) const;
+
   // Decode a single label into a string.
   std::string DecodeSingle(unsigned int label) const;
 
-  // Encode a single character/output class into a label.
+  // Encode a single character/output class into a label. The character must be
+  // in the alphabet; this method asserts that. Use `CanEncodeSingle` to check.
   unsigned int EncodeSingle(const std::string& string) const;
 
   // Decode a sequence of labels into a string.
@@ -52,6 +61,8 @@ public:
 
   // Encode a sequence of character/output classes into a sequence of labels.
   // Characters are assumed to always take a single Unicode codepoint.
+  // All characters must be in the alphabet; this method asserts that. Use
+  // `CanEncode` and `CanEncodeSingle` to check beforehand.
   virtual std::vector Encode(const std::string& input) const;
 
 protected:
@@ -78,6 +89,8 @@ public:
     return 0;
   }
 
+  bool CanEncodeSingle(const std::string& string) const override;
+  bool CanEncode(const std::string& string) const override;
   std::vector Encode(const std::string& input) const override;
 };
 
diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py
index e66633b6..18f402a7 100644
--- a/native_client/ctcdecode/__init__.py
+++ b/native_client/ctcdecode/__init__.py
@@ -47,6 +47,12 @@ class Alphabet(swigwrapper.Alphabet):
         if err != 0:
             raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
 
+    def CanEncodeSingle(self, input):  # UTF-8-encodes `input`, then defers to the base class
+        return super(Alphabet, self).CanEncodeSingle(input.encode('utf-8'))
+
+    def CanEncode(self, input):  # UTF-8-encodes `input`, then defers to the base class
+        return super(Alphabet, self).CanEncode(input.encode('utf-8'))
+
     def EncodeSingle(self, input):
         return super(Alphabet, self).EncodeSingle(input.encode('utf-8'))
 