Add Alphabet.encode analog to .decode and better encapsulate implementation details
This commit is contained in:
parent
6e287bd340
commit
f0688ec941
@ -158,7 +158,7 @@ if __name__ == "__main__":
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
[ALPHABET.label_from_string(c) for c in label]
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
return label
|
||||
|
@ -206,7 +206,7 @@ if __name__ == "__main__":
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
[ALPHABET.label_from_string(c) for c in label]
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
return label
|
||||
|
@ -190,7 +190,7 @@ if __name__ == "__main__":
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
[ALPHABET.label_from_string(c) for c in label]
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
return label
|
||||
|
@ -213,7 +213,7 @@ if __name__ == "__main__":
|
||||
label = validate_label(label)
|
||||
if ALPHABET and label:
|
||||
try:
|
||||
[ALPHABET.label_from_string(c) for c in label]
|
||||
ALPHABET.encode(label)
|
||||
except KeyError:
|
||||
label = None
|
||||
return label
|
||||
|
@ -34,11 +34,11 @@ def sparse_tensor_value_to_texts(value, alphabet):
|
||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||
indices = sp_tuple[0]
|
||||
values = sp_tuple[1]
|
||||
results = [''] * sp_tuple[2][0]
|
||||
results = [[]] * sp_tuple[2][0]
|
||||
for i, index in enumerate(indices):
|
||||
results[index[0]] += alphabet.string_from_label(values[i])
|
||||
results[index[0]].append(values[i])
|
||||
# List of strings
|
||||
return results
|
||||
return [alphabet.decode(res) for res in results]
|
||||
|
||||
|
||||
def evaluate(test_csvs, create_model, try_loading):
|
||||
|
14
util/text.py
14
util/text.py
@ -23,10 +23,10 @@ class Alphabet(object):
|
||||
self._str_to_label[line[:-1]] = self._size
|
||||
self._size += 1
|
||||
|
||||
def string_from_label(self, label):
|
||||
def _string_from_label(self, label):
|
||||
return self._label_to_str[label]
|
||||
|
||||
def label_from_string(self, string):
|
||||
def _label_from_string(self, string):
|
||||
try:
|
||||
return self._str_to_label[string]
|
||||
except KeyError as e:
|
||||
@ -36,10 +36,16 @@ class Alphabet(object):
|
||||
'then add all these to data/alphabet.txt.'.format(string)
|
||||
).with_traceback(e.__traceback__)
|
||||
|
||||
def encode(self, string):
|
||||
res = []
|
||||
for char in string:
|
||||
res.append(self._label_from_string(char))
|
||||
return res
|
||||
|
||||
def decode(self, labels):
|
||||
res = ''
|
||||
for label in labels:
|
||||
res += self.string_from_label(label)
|
||||
res += self._string_from_label(label)
|
||||
return res
|
||||
|
||||
def size(self):
|
||||
@ -55,7 +61,7 @@ def text_to_char_array(series, alphabet):
|
||||
integers and return a numpy array representing the processed string.
|
||||
"""
|
||||
try:
|
||||
series['transcript'] = np.asarray([alphabet.label_from_string(c) for c in series['transcript']])
|
||||
series['transcript'] = np.asarray(alphabet.encode(series['transcript']))
|
||||
except KeyError as e:
|
||||
# Provide the row context (especially wav_filename) for alphabet errors
|
||||
raise ValueError(str(e), series)
|
||||
|
Loading…
x
Reference in New Issue
Block a user