Fix linter errors

X-DeepSpeech: NOBUILD
This commit is contained in:
Reuben Morais 2020-01-21 12:49:51 +01:00
parent 1e2eb96248
commit 3b54f54524
2 changed files with 48 additions and 18 deletions

View File

@ -7,7 +7,7 @@ extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not # Add files or directories to the blacklist. They should be base names, not
# paths. # paths.
ignore=examples ignore=native_client/kenlm
# Add files or directories matching the regex patterns to the blacklist. The # Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths. # regex matches against base names, not paths.

View File

@ -5,7 +5,8 @@ from __future__ import absolute_import, division, print_function
# This script needs to be run from the root of the DeepSpeech repository # This script needs to be run from the root of the DeepSpeech repository
import os import os
import sys import sys
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
import argparse import argparse
import shutil import shutil
@ -14,13 +15,21 @@ from util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8, default_alpha, default_beta): def create_bundle(
alphabet_path,
lm_path,
vocab_path,
package_path,
force_utf8,
default_alpha,
default_beta,
):
words = set() words = set()
vocab_looks_char_based = True vocab_looks_char_based = True
with open(vocab_path) as fin: with open(vocab_path) as fin:
for line in fin: for line in fin:
for word in line.split(): for word in line.split():
words.add(word.encode('utf-8')) words.add(word.encode("utf-8"))
if len(word) > 1: if len(word) > 1:
vocab_looks_char_based = False vocab_looks_char_based = False
print("{} unique words read from vocabulary file.".format(len(words))) print("{} unique words read from vocabulary file.".format(len(words)))
@ -30,7 +39,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8,
) )
) )
if force_utf8 != None: if force_utf8 != None: # pylint: disable=singleton-comparison
use_utf8 = force_utf8.value use_utf8 = force_utf8.value
else: else:
use_utf8 = vocab_looks_char_based use_utf8 = vocab_looks_char_based
@ -54,7 +63,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8,
scorer.fill_dictionary(list(words)) scorer.fill_dictionary(list(words))
shutil.copy(lm_path, package_path) shutil.copy(lm_path, package_path)
scorer.save_dictionary(package_path, True) # append, not overwrite scorer.save_dictionary(package_path, True) # append, not overwrite
print('Package created in {}'.format(package_path)) print("Package created in {}".format(package_path))
class Tristate(object): class Tristate(object):
@ -65,8 +74,11 @@ class Tristate(object):
raise ValueError("Tristate value must be True, False, or None") raise ValueError("Tristate value must be True, False, or None")
def __eq__(self, other): def __eq__(self, other):
return (self.value is other.value if isinstance(other, Tristate) return (
else self.value is other) self.value is other.value
if isinstance(other, Tristate)
else self.value is other
)
def __ne__(self, other): def __ne__(self, other):
return not self == other return not self == other
@ -100,8 +112,18 @@ def main():
help="Path of vocabulary file. Must contain words separated by whitespace.", help="Path of vocabulary file. Must contain words separated by whitespace.",
) )
parser.add_argument("--package", required=True, help="Path to save scorer package.") parser.add_argument("--package", required=True, help="Path to save scorer package.")
parser.add_argument("--default_alpha", type=float, required=True, help="Default value of alpha hyperparameter.") parser.add_argument(
parser.add_argument("--default_beta", type=float, required=True, help="Default value of beta hyperparameter.") "--default_alpha",
type=float,
required=True,
help="Default value of alpha hyperparameter.",
)
parser.add_argument(
"--default_beta",
type=float,
required=True,
help="Default value of beta hyperparameter.",
)
parser.add_argument( parser.add_argument(
"--force_utf8", "--force_utf8",
default="", default="",
@ -116,7 +138,15 @@ def main():
else: else:
force_utf8 = Tristate(None) force_utf8 = Tristate(None)
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8, args.default_alpha, args.default_beta) create_bundle(
args.alphabet,
args.lm,
args.vocab,
args.package,
force_utf8,
args.default_alpha,
args.default_beta,
)
if __name__ == "__main__": if __name__ == "__main__":