diff --git a/bin/import_cv2.py b/bin/import_cv2.py
index 083d2176..122755a4 100755
--- a/bin/import_cv2.py
+++ b/bin/import_cv2.py
@@ -158,7 +158,7 @@ if __name__ == "__main__":
         label = validate_label(label)
         if ALPHABET and label:
             try:
-                [ALPHABET.label_from_string(c) for c in label]
+                ALPHABET.encode(label)
             except KeyError:
                 label = None
         return label
diff --git a/bin/import_lingua_libre.py b/bin/import_lingua_libre.py
index 5b126406..8706fde6 100644
--- a/bin/import_lingua_libre.py
+++ b/bin/import_lingua_libre.py
@@ -206,7 +206,7 @@ if __name__ == "__main__":
         label = validate_label(label)
         if ALPHABET and label:
             try:
-                [ALPHABET.label_from_string(c) for c in label]
+                ALPHABET.encode(label)
             except KeyError:
                 label = None
         return label
diff --git a/bin/import_m-ailabs.py b/bin/import_m-ailabs.py
index ff59ad0f..7b8ca5b5 100644
--- a/bin/import_m-ailabs.py
+++ b/bin/import_m-ailabs.py
@@ -190,7 +190,7 @@ if __name__ == "__main__":
         label = validate_label(label)
         if ALPHABET and label:
             try:
-                [ALPHABET.label_from_string(c) for c in label]
+                ALPHABET.encode(label)
             except KeyError:
                 label = None
         return label
diff --git a/bin/import_slr57.py b/bin/import_slr57.py
index 8520a1b3..67ea54d0 100644
--- a/bin/import_slr57.py
+++ b/bin/import_slr57.py
@@ -213,7 +213,7 @@ if __name__ == "__main__":
         label = validate_label(label)
         if ALPHABET and label:
             try:
-                [ALPHABET.label_from_string(c) for c in label]
+                ALPHABET.encode(label)
             except KeyError:
                 label = None
         return label
diff --git a/evaluate.py b/evaluate.py
index 4cda4523..e511c8b6 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -34,11 +34,13 @@ def sparse_tensor_value_to_texts(value, alphabet):
 def sparse_tuple_to_texts(sp_tuple, alphabet):
     indices = sp_tuple[0]
     values = sp_tuple[1]
-    results = [''] * sp_tuple[2][0]
+    # Build one independent list per result row: `[[]] * n` would repeat a
+    # single shared list, so every append would show up in all rows.
+    results = [[] for _ in range(sp_tuple[2][0])]
     for i, index in enumerate(indices):
-        results[index[0]] += alphabet.string_from_label(values[i])
+        results[index[0]].append(values[i])
     # List of strings
-    return results
+    return [alphabet.decode(res) for res in results]
 
 
 def evaluate(test_csvs, create_model, try_loading):
diff --git a/util/text.py b/util/text.py
index 7c1b1ee5..0db0bb25 100644
--- a/util/text.py
+++ b/util/text.py
@@ -23,10 +23,10 @@ class Alphabet(object):
                 self._str_to_label[line[:-1]] = self._size
                 self._size += 1
 
-    def string_from_label(self, label):
+    def _string_from_label(self, label):
         return self._label_to_str[label]
 
-    def label_from_string(self, string):
+    def _label_from_string(self, string):
         try:
             return self._str_to_label[string]
         except KeyError as e:
@@ -36,10 +36,14 @@ class Alphabet(object):
                 'then add all these to data/alphabet.txt.'.format(string)
             ).with_traceback(e.__traceback__)
 
+    def encode(self, string):
+        """Return the list of labels for the characters of `string`."""
+        return [self._label_from_string(char) for char in string]
+
     def decode(self, labels):
         res = ''
         for label in labels:
-            res += self.string_from_label(label)
+            res += self._string_from_label(label)
         return res
 
     def size(self):
@@ -55,7 +59,7 @@ def text_to_char_array(series, alphabet):
     integers and return a numpy array representing the processed string.
     """
     try:
-        series['transcript'] = np.asarray([alphabet.label_from_string(c) for c in series['transcript']])
+        series['transcript'] = np.asarray(alphabet.encode(series['transcript']))
     except KeyError as e:
         # Provide the row context (especially wav_filename) for alphabet errors
         raise ValueError(str(e), series)