Fix linter errors
X-DeepSpeech: NOBUILD
This commit is contained in:
parent
1e2eb96248
commit
3b54f54524
@ -7,7 +7,7 @@ extension-pkg-whitelist=
|
|||||||
|
|
||||||
# Add files or directories to the blacklist. They should be base names, not
|
# Add files or directories to the blacklist. They should be base names, not
|
||||||
# paths.
|
# paths.
|
||||||
ignore=examples
|
ignore=native_client/kenlm
|
||||||
|
|
||||||
# Add files or directories matching the regex patterns to the blacklist. The
|
# Add files or directories matching the regex patterns to the blacklist. The
|
||||||
# regex matches against base names, not paths.
|
# regex matches against base names, not paths.
|
||||||
|
@ -5,7 +5,8 @@ from __future__ import absolute_import, division, print_function
|
|||||||
# This script needs to be run from the root of the DeepSpeech repository
|
# This script needs to be run from the root of the DeepSpeech repository
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
|
|
||||||
|
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import shutil
|
import shutil
|
||||||
@ -14,13 +15,21 @@ from util.text import Alphabet, UTF8Alphabet
|
|||||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||||
|
|
||||||
|
|
||||||
def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8, default_alpha, default_beta):
|
def create_bundle(
|
||||||
|
alphabet_path,
|
||||||
|
lm_path,
|
||||||
|
vocab_path,
|
||||||
|
package_path,
|
||||||
|
force_utf8,
|
||||||
|
default_alpha,
|
||||||
|
default_beta,
|
||||||
|
):
|
||||||
words = set()
|
words = set()
|
||||||
vocab_looks_char_based = True
|
vocab_looks_char_based = True
|
||||||
with open(vocab_path) as fin:
|
with open(vocab_path) as fin:
|
||||||
for line in fin:
|
for line in fin:
|
||||||
for word in line.split():
|
for word in line.split():
|
||||||
words.add(word.encode('utf-8'))
|
words.add(word.encode("utf-8"))
|
||||||
if len(word) > 1:
|
if len(word) > 1:
|
||||||
vocab_looks_char_based = False
|
vocab_looks_char_based = False
|
||||||
print("{} unique words read from vocabulary file.".format(len(words)))
|
print("{} unique words read from vocabulary file.".format(len(words)))
|
||||||
@ -30,7 +39,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8,
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if force_utf8 != None:
|
if force_utf8 != None: # pylint: disable=singleton-comparison
|
||||||
use_utf8 = force_utf8.value
|
use_utf8 = force_utf8.value
|
||||||
else:
|
else:
|
||||||
use_utf8 = vocab_looks_char_based
|
use_utf8 = vocab_looks_char_based
|
||||||
@ -54,7 +63,7 @@ def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8,
|
|||||||
scorer.fill_dictionary(list(words))
|
scorer.fill_dictionary(list(words))
|
||||||
shutil.copy(lm_path, package_path)
|
shutil.copy(lm_path, package_path)
|
||||||
scorer.save_dictionary(package_path, True) # append, not overwrite
|
scorer.save_dictionary(package_path, True) # append, not overwrite
|
||||||
print('Package created in {}'.format(package_path))
|
print("Package created in {}".format(package_path))
|
||||||
|
|
||||||
|
|
||||||
class Tristate(object):
|
class Tristate(object):
|
||||||
@ -65,8 +74,11 @@ class Tristate(object):
|
|||||||
raise ValueError("Tristate value must be True, False, or None")
|
raise ValueError("Tristate value must be True, False, or None")
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return (self.value is other.value if isinstance(other, Tristate)
|
return (
|
||||||
else self.value is other)
|
self.value is other.value
|
||||||
|
if isinstance(other, Tristate)
|
||||||
|
else self.value is other
|
||||||
|
)
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self == other
|
return not self == other
|
||||||
@ -100,8 +112,18 @@ def main():
|
|||||||
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
||||||
parser.add_argument("--default_alpha", type=float, required=True, help="Default value of alpha hyperparameter.")
|
parser.add_argument(
|
||||||
parser.add_argument("--default_beta", type=float, required=True, help="Default value of beta hyperparameter.")
|
"--default_alpha",
|
||||||
|
type=float,
|
||||||
|
required=True,
|
||||||
|
help="Default value of alpha hyperparameter.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--default_beta",
|
||||||
|
type=float,
|
||||||
|
required=True,
|
||||||
|
help="Default value of beta hyperparameter.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--force_utf8",
|
"--force_utf8",
|
||||||
default="",
|
default="",
|
||||||
@ -116,7 +138,15 @@ def main():
|
|||||||
else:
|
else:
|
||||||
force_utf8 = Tristate(None)
|
force_utf8 = Tristate(None)
|
||||||
|
|
||||||
create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8, args.default_alpha, args.default_beta)
|
create_bundle(
|
||||||
|
args.alphabet,
|
||||||
|
args.lm,
|
||||||
|
args.vocab,
|
||||||
|
args.package,
|
||||||
|
force_utf8,
|
||||||
|
args.default_alpha,
|
||||||
|
args.default_beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user