Merge pull request #2454 from mozilla/encapsulate-alphabet
Add Alphabet.encode analog to .decode and better encapsulate implementation details
This commit is contained in:
commit
ca401b0813
@ -158,7 +158,7 @@ if __name__ == "__main__":
|
|||||||
label = validate_label(label)
|
label = validate_label(label)
|
||||||
if ALPHABET and label:
|
if ALPHABET and label:
|
||||||
try:
|
try:
|
||||||
[ALPHABET.label_from_string(c) for c in label]
|
ALPHABET.encode(label)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
label = None
|
label = None
|
||||||
return label
|
return label
|
||||||
|
|||||||
@ -206,7 +206,7 @@ if __name__ == "__main__":
|
|||||||
label = validate_label(label)
|
label = validate_label(label)
|
||||||
if ALPHABET and label:
|
if ALPHABET and label:
|
||||||
try:
|
try:
|
||||||
[ALPHABET.label_from_string(c) for c in label]
|
ALPHABET.encode(label)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
label = None
|
label = None
|
||||||
return label
|
return label
|
||||||
|
|||||||
@ -190,7 +190,7 @@ if __name__ == "__main__":
|
|||||||
label = validate_label(label)
|
label = validate_label(label)
|
||||||
if ALPHABET and label:
|
if ALPHABET and label:
|
||||||
try:
|
try:
|
||||||
[ALPHABET.label_from_string(c) for c in label]
|
ALPHABET.encode(label)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
label = None
|
label = None
|
||||||
return label
|
return label
|
||||||
|
|||||||
@ -213,7 +213,7 @@ if __name__ == "__main__":
|
|||||||
label = validate_label(label)
|
label = validate_label(label)
|
||||||
if ALPHABET and label:
|
if ALPHABET and label:
|
||||||
try:
|
try:
|
||||||
[ALPHABET.label_from_string(c) for c in label]
|
ALPHABET.encode(label)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
label = None
|
label = None
|
||||||
return label
|
return label
|
||||||
|
|||||||
@ -34,11 +34,11 @@ def sparse_tensor_value_to_texts(value, alphabet):
|
|||||||
def sparse_tuple_to_texts(sp_tuple, alphabet):
    """Turn a sparse ``(indices, values, shape)`` tuple into decoded strings.

    ``sp_tuple[0]`` holds index entries whose first component selects the
    output row, ``sp_tuple[1]`` holds the label stored at each of those
    entries, and ``sp_tuple[2][0]`` is the number of output rows.

    Labels are gathered per row in the order they appear and each row is
    then passed to ``alphabet.decode``. Rows that receive no labels are
    decoded from an empty list.

    Returns a list with one decoded entry per row.
    """
    indices = sp_tuple[0]
    values = sp_tuple[1]
    # Build one independent list per row. ``[[]] * n`` would repeat a single
    # list object n times, so every append would show up in every row.
    results = [[] for _ in range(sp_tuple[2][0])]
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    # List of strings
    return [alphabet.decode(res) for res in results]
|
||||||
|
|
||||||
|
|
||||||
def evaluate(test_csvs, create_model, try_loading):
|
def evaluate(test_csvs, create_model, try_loading):
|
||||||
|
|||||||
14
util/text.py
14
util/text.py
@ -23,10 +23,10 @@ class Alphabet(object):
|
|||||||
self._str_to_label[line[:-1]] = self._size
|
self._str_to_label[line[:-1]] = self._size
|
||||||
self._size += 1
|
self._size += 1
|
||||||
|
|
||||||
def string_from_label(self, label):
|
def _string_from_label(self, label):
|
||||||
return self._label_to_str[label]
|
return self._label_to_str[label]
|
||||||
|
|
||||||
def label_from_string(self, string):
|
def _label_from_string(self, string):
|
||||||
try:
|
try:
|
||||||
return self._str_to_label[string]
|
return self._str_to_label[string]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
@ -36,10 +36,16 @@ class Alphabet(object):
|
|||||||
'then add all these to data/alphabet.txt.'.format(string)
|
'then add all these to data/alphabet.txt.'.format(string)
|
||||||
).with_traceback(e.__traceback__)
|
).with_traceback(e.__traceback__)
|
||||||
|
|
||||||
|
def encode(self, string):
    """Convert ``string`` into a list of labels, one per character.

    Each character is looked up with ``self._label_from_string`` and the
    results are returned in the same order as the input characters. An
    empty string gives an empty list. Any exception raised by the lookup
    propagates to the caller.
    """
    return [self._label_from_string(char) for char in string]
|
||||||
|
|
||||||
def decode(self, labels):
    """Convert an iterable of labels back into a single string.

    Each label is looked up with ``self._string_from_label`` and the
    resulting pieces are concatenated in order. An empty ``labels`` gives
    an empty string.
    """
    # join builds the result in one pass instead of repeated ``+=`` copies
    return ''.join(self._string_from_label(label) for label in labels)
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
@ -55,7 +61,7 @@ def text_to_char_array(series, alphabet):
|
|||||||
integers and return a numpy array representing the processed string.
|
integers and return a numpy array representing the processed string.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
series['transcript'] = np.asarray([alphabet.label_from_string(c) for c in series['transcript']])
|
series['transcript'] = np.asarray(alphabet.encode(series['transcript']))
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
# Provide the row context (especially wav_filename) for alphabet errors
|
# Provide the row context (especially wav_filename) for alphabet errors
|
||||||
raise ValueError(str(e), series)
|
raise ValueError(str(e), series)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user