Add Alphabet.encode analog to .decode and better encapsulate implementation details

This commit is contained in:
Reuben Morais 2019-10-22 14:47:53 +02:00
parent 6e287bd340
commit f0688ec941
6 changed files with 17 additions and 11 deletions

View File

@ -158,7 +158,7 @@ if __name__ == "__main__":
label = validate_label(label)
if ALPHABET and label:
try:
[ALPHABET.label_from_string(c) for c in label]
ALPHABET.encode(label)
except KeyError:
label = None
return label

View File

@ -206,7 +206,7 @@ if __name__ == "__main__":
label = validate_label(label)
if ALPHABET and label:
try:
[ALPHABET.label_from_string(c) for c in label]
ALPHABET.encode(label)
except KeyError:
label = None
return label

View File

@ -190,7 +190,7 @@ if __name__ == "__main__":
label = validate_label(label)
if ALPHABET and label:
try:
[ALPHABET.label_from_string(c) for c in label]
ALPHABET.encode(label)
except KeyError:
label = None
return label

View File

@ -213,7 +213,7 @@ if __name__ == "__main__":
label = validate_label(label)
if ALPHABET and label:
try:
[ALPHABET.label_from_string(c) for c in label]
ALPHABET.encode(label)
except KeyError:
label = None
return label

View File

@ -34,11 +34,11 @@ def sparse_tensor_value_to_texts(value, alphabet):
def sparse_tuple_to_texts(sp_tuple, alphabet):
    """Group the labels of a sparse tuple by row and decode each row.

    Args:
        sp_tuple: Indexable triple. ``sp_tuple[0]`` holds the index entries
            (the first element of each entry is the row it belongs to),
            ``sp_tuple[1]`` holds the label for each entry, and
            ``sp_tuple[2][0]`` is the number of rows.
        alphabet: Object exposing ``decode(labels)``.

    Returns:
        A list with one ``alphabet.decode(...)`` result per row, in row order.
        Rows that received no labels are decoded from an empty list.
    """
    indices = sp_tuple[0]
    values = sp_tuple[1]
    # Build one *distinct* list per row. ``[[]] * n`` would repeat a reference
    # to a single shared list, so every append would show up in every row and
    # all decoded texts would come out identical.
    results = [[] for _ in range(sp_tuple[2][0])]
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    # List of strings
    return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model, try_loading):

View File

@ -23,10 +23,10 @@ class Alphabet(object):
self._str_to_label[line[:-1]] = self._size
self._size += 1
def string_from_label(self, label):
def _string_from_label(self, label):
return self._label_to_str[label]
def label_from_string(self, string):
def _label_from_string(self, string):
try:
return self._str_to_label[string]
except KeyError as e:
@ -36,10 +36,16 @@ class Alphabet(object):
'then add all these to data/alphabet.txt.'.format(string)
).with_traceback(e.__traceback__)
def encode(self, string):
    """Convert ``string`` into a list of labels, one per character.

    Each character is looked up with ``self._label_from_string``; the results
    are returned in the same order as the characters. An empty string yields
    an empty list.

    Raises:
        Whatever ``self._label_from_string`` raises for a character it cannot
        map (callers in this code base catch ``KeyError``).
    """
    return [self._label_from_string(char) for char in string]
def decode(self, labels):
    """Convert an iterable of labels back into a single string.

    Each label is looked up with ``self._string_from_label`` and the pieces
    are concatenated in order. An empty iterable yields ``''``.

    Raises:
        Whatever ``self._string_from_label`` raises for an unknown label.
    """
    # Join once instead of growing the string with ``+=`` on every label.
    return ''.join(self._string_from_label(label) for label in labels)
def size(self):
@ -55,7 +61,7 @@ def text_to_char_array(series, alphabet):
integers and return a numpy array representing the processed string.
"""
try:
series['transcript'] = np.asarray([alphabet.label_from_string(c) for c in series['transcript']])
series['transcript'] = np.asarray(alphabet.encode(series['transcript']))
except KeyError as e:
# Provide the row context (especially wav_filename) for alphabet errors
raise ValueError(str(e), series)