Merge pull request #2040 from mozilla/pylint-hook
Add linter config for Python and CI integration on PRs
commit 68c17611c6

.cardboardlint.yml (new file, 2 lines)
@@ -0,0 +1,2 @@
linters:
- pylint:
.pylintrc (new file, 581 lines)
@@ -0,0 +1,581 @@
|
||||
[MASTER]
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code.
|
||||
extension-pkg-whitelist=
|
||||
|
||||
# Add files or directories to the blacklist. They should be base names, not
|
||||
# paths.
|
||||
ignore=CVS
|
||||
|
||||
# Add files or directories matching the regex patterns to the blacklist. The
|
||||
# regex matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
|
||||
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
||||
# number of processors available to use.
|
||||
jobs=1
|
||||
|
||||
# Control the amount of potential inferred values when inferring a single
|
||||
# object. This can help the performance when dealing with large functions or
|
||||
# complex, nested conditions.
|
||||
limit-inference-results=100
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=yes
|
||||
|
||||
# Specify a configuration file.
|
||||
#rcfile=
|
||||
|
||||
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
||||
# user-friendly hints instead of false-positive error messages.
|
||||
suggestion-mode=yes
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
|
||||
confidence=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once). You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use "--disable=all --enable=classes
|
||||
# --disable=W".
|
||||
disable=missing-docstring,
|
||||
line-too-long,
|
||||
wrong-import-order,
|
||||
ungrouped-imports,
|
||||
wrong-import-position,
|
||||
import-error,
|
||||
no-name-in-module,
|
||||
no-member,
|
||||
unsubscriptable-object,
|
||||
print-statement,
|
||||
parameter-unpacking,
|
||||
unpacking-in-except,
|
||||
old-raise-syntax,
|
||||
backtick,
|
||||
long-suffix,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
import-star-module-level,
|
||||
non-ascii-bytes-literal,
|
||||
raw-checker-failed,
|
||||
bad-inline-option,
|
||||
locally-disabled,
|
||||
file-ignored,
|
||||
suppressed-message,
|
||||
useless-suppression,
|
||||
deprecated-pragma,
|
||||
use-symbolic-message-instead,
|
||||
useless-object-inheritance,
|
||||
too-few-public-methods,
|
||||
too-many-branches,
|
||||
too-many-arguments,
|
||||
too-many-locals,
|
||||
too-many-statements,
|
||||
apply-builtin,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
cmp-builtin,
|
||||
coerce-builtin,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
long-builtin,
|
||||
raw_input-builtin,
|
||||
reduce-builtin,
|
||||
standarderror-builtin,
|
||||
unicode-builtin,
|
||||
xrange-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
getslice-method,
|
||||
setslice-method,
|
||||
no-absolute-import,
|
||||
old-division,
|
||||
dict-iter-method,
|
||||
dict-view-method,
|
||||
next-method-called,
|
||||
metaclass-assignment,
|
||||
indexing-exception,
|
||||
raising-string,
|
||||
reload-builtin,
|
||||
oct-method,
|
||||
hex-method,
|
||||
nonzero-method,
|
||||
cmp-method,
|
||||
input-builtin,
|
||||
round-builtin,
|
||||
intern-builtin,
|
||||
unichr-builtin,
|
||||
map-builtin-not-iterating,
|
||||
zip-builtin-not-iterating,
|
||||
range-builtin-not-iterating,
|
||||
filter-builtin-not-iterating,
|
||||
using-cmp-argument,
|
||||
eq-without-hash,
|
||||
div-method,
|
||||
idiv-method,
|
||||
rdiv-method,
|
||||
exception-message-attribute,
|
||||
invalid-str-codec,
|
||||
sys-max-int,
|
||||
bad-python3-import,
|
||||
deprecated-string-function,
|
||||
deprecated-str-translate-call,
|
||||
deprecated-itertools-function,
|
||||
deprecated-types-field,
|
||||
next-method-defined,
|
||||
dict-items-not-iterating,
|
||||
dict-keys-not-iterating,
|
||||
dict-values-not-iterating,
|
||||
deprecated-operator-function,
|
||||
deprecated-urllib-function,
|
||||
xreadlines-attribute,
|
||||
deprecated-sys-function,
|
||||
exception-escape,
|
||||
comprehension-escape
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=c-extension-no-member
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details.
|
||||
#msg-template=
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, json
|
||||
# and msvs (visual studio). You can also give a reporter class, e.g.
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages.
|
||||
reports=no
|
||||
|
||||
# Activate the evaluation score.
|
||||
score=yes
|
||||
|
||||
|
||||
[REFACTORING]
|
||||
|
||||
# Maximum number of nested blocks for function / method body
|
||||
max-nested-blocks=5
|
||||
|
||||
# Complete name of functions that never returns. When checking for
|
||||
# inconsistent-return-statements if a never returning function is called then
|
||||
# it will be considered as an explicit return statement and no message will be
|
||||
# printed.
|
||||
never-returning-functions=sys.exit
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Format style used to check logging format string. `old` means using %
|
||||
# formatting, while `new` is for `{}` formatting.
|
||||
logging-format-style=old
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format.
|
||||
logging-modules=logging
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Limits count of emitted suggestions for spelling mistakes.
|
||||
max-spelling-suggestions=4
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package..
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=FIXME,
|
||||
XXX,
|
||||
TODO
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# Tells whether to warn about missing members when the owner of the attribute
|
||||
# is inferred to be None.
|
||||
ignore-none=yes
|
||||
|
||||
# This flag controls whether pylint should warn about no-member and similar
|
||||
# checks whenever an opaque object is returned when inferring. The inference
|
||||
# can return multiple potential results while evaluating a Python object, but
|
||||
# some branches might not be evaluated, which results in partial inference. In
|
||||
# that case, it might be useful to still emit no-member and other checks for
|
||||
# the rest of the inferred objects.
|
||||
ignore-on-opaque-inference=yes
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# Show a hint with possible names when a member name was not found. The aspect
|
||||
# of finding the hint is based on edit distance.
|
||||
missing-member-hint=yes
|
||||
|
||||
# The minimum edit distance a name should have in order to be considered a
|
||||
# similar match for a missing member name.
|
||||
missing-member-hint-distance=1
|
||||
|
||||
# The total number of similar names that should be taken in consideration when
|
||||
# showing a hint for a missing member.
|
||||
missing-member-max-choices=1
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid defining new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# Tells whether unused global variables should be treated as a violation.
|
||||
allow-global-unused-variables=yes
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,
|
||||
_cb
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expected to
|
||||
# not be used).
|
||||
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
||||
|
||||
# Argument names that match this expression will be ignored. Default to name
|
||||
# with leading underscore.
|
||||
ignored-argument-names=_.*|^ignored_|^unused_
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||
# tab).
|
||||
indent-string=' '
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=100
|
||||
|
||||
# Maximum number of lines in a module.
|
||||
max-module-lines=1000
|
||||
|
||||
# List of optional constructs for which whitespace checking is disabled. `dict-
|
||||
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
||||
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
||||
# `empty-line` allows space-only lines.
|
||||
no-space-check=trailing-comma,
|
||||
dict-separator
|
||||
|
||||
# Allow the body of a class to be on the same line as the declaration if body
|
||||
# contains single statement.
|
||||
single-line-class-stmt=no
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=no
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Naming style matching correct argument names.
|
||||
argument-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct argument names. Overrides argument-
|
||||
# naming-style.
|
||||
argument-rgx=[a-z_][a-z0-9_]{0,30}$
|
||||
|
||||
# Naming style matching correct attribute names.
|
||||
attr-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct attribute names. Overrides attr-naming-
|
||||
# style.
|
||||
#attr-rgx=
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma.
|
||||
bad-names=
|
||||
|
||||
# Naming style matching correct class attribute names.
|
||||
class-attribute-naming-style=any
|
||||
|
||||
# Regular expression matching correct class attribute names. Overrides class-
|
||||
# attribute-naming-style.
|
||||
#class-attribute-rgx=
|
||||
|
||||
# Naming style matching correct class names.
|
||||
class-naming-style=PascalCase
|
||||
|
||||
# Regular expression matching correct class names. Overrides class-naming-
|
||||
# style.
|
||||
#class-rgx=
|
||||
|
||||
# Naming style matching correct constant names.
|
||||
const-naming-style=UPPER_CASE
|
||||
|
||||
# Regular expression matching correct constant names. Overrides const-naming-
|
||||
# style.
|
||||
#const-rgx=
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=-1
|
||||
|
||||
# Naming style matching correct function names.
|
||||
function-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct function names. Overrides function-
|
||||
# naming-style.
|
||||
#function-rgx=
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma.
|
||||
good-names=i,
|
||||
j,
|
||||
k,
|
||||
x,
|
||||
ex,
|
||||
Run,
|
||||
_
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name.
|
||||
include-naming-hint=no
|
||||
|
||||
# Naming style matching correct inline iteration names.
|
||||
inlinevar-naming-style=any
|
||||
|
||||
# Regular expression matching correct inline iteration names. Overrides
|
||||
# inlinevar-naming-style.
|
||||
#inlinevar-rgx=
|
||||
|
||||
# Naming style matching correct method names.
|
||||
method-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct method names. Overrides method-naming-
|
||||
# style.
|
||||
#method-rgx=
|
||||
|
||||
# Naming style matching correct module names.
|
||||
module-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct module names. Overrides module-naming-
|
||||
# style.
|
||||
#module-rgx=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=^_
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
# These decorators are taken in consideration only for invalid-name.
|
||||
property-classes=abc.abstractproperty
|
||||
|
||||
# Naming style matching correct variable names.
|
||||
variable-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct variable names. Overrides variable-
|
||||
# naming-style.
|
||||
variable-rgx=[a-z_][a-z0-9_]{0,30}$
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether the implicit-str-concat-in-sequence should
|
||||
# generate a warning on implicit string concatenation in sequences defined over
|
||||
# several lines.
|
||||
check-str-concat-over-line-jumps=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Allow wildcard imports from modules that define __all__.
|
||||
allow-wildcard-with-all=no
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma.
|
||||
deprecated-modules=optparse,tkinter.tix
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled).
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled).
|
||||
import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled).
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=cls
|
||||
|
||||
|
||||
[DESIGN]
|
||||
|
||||
# Maximum number of arguments for function / method.
|
||||
max-args=5
|
||||
|
||||
# Maximum number of attributes for a class (see R0902).
|
||||
max-attributes=7
|
||||
|
||||
# Maximum number of boolean expressions in an if statement.
|
||||
max-bool-expr=5
|
||||
|
||||
# Maximum number of branch for function / method body.
|
||||
max-branches=12
|
||||
|
||||
# Maximum number of locals for function / method body.
|
||||
max-locals=15
|
||||
|
||||
# Maximum number of parents for a class (see R0901).
|
||||
max-parents=7
|
||||
|
||||
# Maximum number of public methods for a class (see R0904).
|
||||
max-public-methods=20
|
||||
|
||||
# Maximum number of return / yield for function / method body.
|
||||
max-returns=6
|
||||
|
||||
# Maximum number of statements in function / method body.
|
||||
max-statements=50
|
||||
|
||||
# Minimum number of public methods for a class (see R0903).
|
||||
min-public-methods=2
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "BaseException, Exception".
|
||||
overgeneral-exceptions=BaseException,
|
||||
Exception
|
.travis.yml (new file, 17 lines)
@@ -0,0 +1,17 @@
language: python

cache: pip
before_cache:
  - rm ~/.cache/pip/log/debug.log

python:
  - "3.6"

install:
  - pip install --upgrade cardboardlint pylint

script:
  # Run cardboardlinter, in case of pull requests
  - if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
      cardboardlinter --refspec $TRAVIS_BRANCH -n auto;
    fi
DeepSpeech.py
@@ -5,17 +5,17 @@ from __future__ import absolute_import, division, print_function
import os
import sys

log_level_index = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[log_level_index] if log_level_index > 0 and log_level_index < len(sys.argv) else '3'
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'

import time
import evaluate
import numpy as np
import progressbar
import shutil
import tensorflow as tf

from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate
from six.moves import zip, range
from tensorflow.python.tools import freeze_graph
from util.config import Config, initialize_globals
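The rename to `LOG_LEVEL_INDEX` satisfies pylint's naming convention for module-level names, and the bounds test becomes a chained comparison. A minimal, self-contained sketch of the equivalence (the names and values are illustrative, not taken from DeepSpeech.py):

```python
# Minimal sketch: a chained comparison is equivalent to two ANDed comparisons.
# `argv` and `idx` are illustrative stand-ins, not DeepSpeech code.
argv = ['train.py', '--log_level', '2']
idx = argv.index('--log_level') + 1 if '--log_level' in argv else 0

old_style = idx > 0 and idx < len(argv)  # the original form
new_style = 0 < idx < len(argv)          # the chained form pylint prefers
assert old_style == new_style
```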
@ -49,7 +49,7 @@ def create_overlapping_windows(batch_x):
|
||||
# convolution returns patches of the input tensor as is, and we can create
|
||||
# overlapping windows over the MFCCs.
|
||||
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32)
|
||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
||||
|
||||
# Create overlapping windows
|
||||
batch_x = tf.nn.conv1d(batch_x, eye_filter, stride=1, padding='SAME')
|
||||
@ -172,7 +172,7 @@ def create_model(batch_x, seq_length, dropout, reuse=False, previous_state=None,
|
||||
# Conveniently, this loss function is implemented in TensorFlow.
|
||||
# Thus, we can simply make use of this implementation to define our loss.
|
||||
|
||||
def calculate_mean_edit_distance_and_loss(iterator, tower, dropout, reuse):
|
||||
def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
|
||||
r'''
|
||||
This routine beam search decodes a mini-batch and calculates the loss and mean edit distance.
|
||||
Next to total and average loss it returns the mean edit distance,
|
||||
@ -246,10 +246,10 @@ def get_tower_results(iterator, optimizer, dropout_rates):
|
||||
device = Config.available_devices[i]
|
||||
with tf.device(device):
|
||||
# Create a scope for all operations of tower i
|
||||
with tf.name_scope('tower_%d' % i) as scope:
|
||||
with tf.name_scope('tower_%d' % i):
|
||||
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
|
||||
# batch along with the original batch's labels (Y) of this tower
|
||||
avg_loss = calculate_mean_edit_distance_and_loss(iterator, i, dropout_rates, reuse=i>0)
|
||||
avg_loss = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
|
||||
|
||||
# Allow for variables to be re-used by the next tower
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
@ -460,9 +460,9 @@ def train():
|
||||
def __init__(self):
|
||||
progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f')
|
||||
|
||||
def __call__(self, progress, data):
|
||||
def __call__(self, progress, data, **kwargs):
|
||||
data['mean_loss'] = total_loss / step_count if step_count else 0.0
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data)
|
||||
return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
|
||||
|
||||
if FLAGS.show_progressbar:
|
||||
pbar = progressbar.ProgressBar(widgets=['Epoch {}'.format(epoch),
|
||||
@ -547,7 +547,7 @@ def train():
|
||||
|
||||
|
||||
def test():
|
||||
evaluate.evaluate(FLAGS.test_files.split(','), create_model, try_loading)
|
||||
evaluate(FLAGS.test_files.split(','), create_model, try_loading)
|
||||
|
||||
|
||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
@ -570,12 +570,12 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = previous_state_c = previous_state_h = None
|
||||
else:
|
||||
if not tflite:
|
||||
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
else:
|
||||
if tflite:
|
||||
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
else:
|
||||
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
|
||||
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
@ -620,28 +620,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
if not tflite:
|
||||
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
|
||||
initialize_c = tf.assign(previous_state_c, zero_state)
|
||||
initialize_h = tf.assign(previous_state_h, zero_state)
|
||||
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
|
||||
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
|
||||
logits = tf.identity(logits, name='logits')
|
||||
|
||||
return (
|
||||
{
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
'input_samples': input_samples,
|
||||
},
|
||||
{
|
||||
'outputs': logits,
|
||||
'initialize_state': initialize_state,
|
||||
'mfccs': mfccs,
|
||||
},
|
||||
layers
|
||||
)
|
||||
else:
|
||||
if tflite:
|
||||
logits = tf.identity(logits, name='logits')
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||
@ -656,17 +635,32 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
if FLAGS.use_seq_length:
|
||||
inputs.update({'input_lengths': seq_length})
|
||||
|
||||
return (
|
||||
inputs,
|
||||
{
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
},
|
||||
layers
|
||||
)
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'new_state_c': new_state_c,
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
else:
|
||||
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
|
||||
initialize_c = tf.assign(previous_state_c, zero_state)
|
||||
initialize_h = tf.assign(previous_state_h, zero_state)
|
||||
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
|
||||
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
|
||||
logits = tf.identity(logits, name='logits')
|
||||
|
||||
inputs = {
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'initialize_state': initialize_state,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
def file_relative_read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
@ -680,11 +674,9 @@ def export():
|
||||
from tensorflow.python.framework.ops import Tensor, Operation
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
input_names = ",".join(tensor.op.name for tensor in inputs.values())
|
||||
output_names_tensors = [ tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [ tensor.name for tensor in outputs.values() if isinstance(tensor, Operation)]
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
output_names = ",".join(output_names_tensors + output_names_ops)
|
||||
input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
|
||||
@ -828,6 +820,6 @@ def main(_):
|
||||
tf.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
if __name__ == '__main__' :
|
||||
if __name__ == '__main__':
|
||||
create_flags()
|
||||
tf.app.run(main)
|
||||
|
README.md (53 lines changed)
@@ -51,6 +51,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`
- [Exporting a model for TFLite](#exporting-a-model-for-tflite)
- [Making a mmap-able model for inference](#making-a-mmap-able-model-for-inference)
- [Continuing training from a release model](#continuing-training-from-a-release-model)
- [Contribution guidelines](#contribution-guidelines)
- [Contact/Getting Help](#contactgetting-help)

## Prerequisites
@@ -372,6 +373,58 @@ python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder

Note: the released models were trained with `--n_hidden 2048`, so you need to use that same value when initializing from the release models.

## Contribution guidelines

This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).

Before making a Pull Request, check your changes for basic mistakes and style problems by using a linter. We have cardboardlinter set up in this repository, so for example, if you've made some changes and would like to run the linter on just the changed code, you can use the following command:

```bash
pip install pylint cardboardlint
cardboardlinter --refspec master
```

This will compare the code against master and run the linter on all the changes. We plan to introduce more linter checks (e.g. for C++) in the future. To run it automatically as a git pre-commit hook, do the following:

```bash
cat <<\EOF > .git/hooks/pre-commit
#!/bin/bash
if [ ! -x "$(command -v cardboardlinter)" ]; then
    exit 0
fi

# First, stash index and work dir, keeping only the
# to-be-committed changes in the working directory.
echo "Stashing working tree changes..." 1>&2
old_stash=$(git rev-parse -q --verify refs/stash)
git stash save -q --keep-index
new_stash=$(git rev-parse -q --verify refs/stash)

# If there were no changes (e.g., `--amend` or `--allow-empty`)
# then nothing was stashed, and we should skip everything,
# including the tests themselves. (Presumably the tests passed
# on the previous commit, so there is no need to re-run them.)
if [ "$old_stash" = "$new_stash" ]; then
    echo "No changes, skipping lint." 1>&2
    exit 0
fi

# Run tests
cardboardlinter --refspec HEAD^ -n auto
status=$?

# Restore changes
echo "Restoring working tree changes..." 1>&2
git reset --hard -q && git stash apply --index -q && git stash drop -q

# Exit with status from test-run: nonzero prevents commit
exit $status
EOF
chmod +x .git/hooks/pre-commit
```

This will run the linters on just the changes made in your commit.

## Contact/Getting Help

There are several ways to contact us or to get help:
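As a purely illustrative sketch (not code from this repository) of the kind of issue the linter reports under the bundled .pylintrc, camelCase arguments and locals trip the invalid-name check, and the fix is a snake_case rename:

```python
# Illustrative only: the bundled .pylintrc expects snake_case for
# arguments and variables, so pylint reports invalid-name (C0103) here.
def count_chars(inFiles):          # flagged: argument name "inFiles"
    allText = set()                # flagged: variable name "allText"
    for inFile in inFiles:         # flagged: variable name "inFile"
        allText |= set(inFile)
    return allText

# Renamed to snake_case, the same function passes the check.
def count_chars_clean(in_files):
    all_text = set()
    for in_file in in_files:
        all_text |= set(in_file)
    return all_text
```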
evaluate.py (28 lines changed)
@@ -4,13 +4,16 @@ from __future__ import absolute_import, division, print_function

import itertools
import json

from multiprocessing import cpu_count

import numpy as np
import progressbar
import tensorflow as tf

from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from multiprocessing import cpu_count
from six.moves import zip, range
from six.moves import zip

from util.config import Config, initialize_globals
from util.evaluate_tools import calculate_report
from util.feeding import create_dataset
@@ -27,13 +30,12 @@ def sparse_tensor_value_to_texts(value, alphabet):
    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)


def sparse_tuple_to_texts(tuple, alphabet):
    indices = tuple[0]
    values = tuple[1]
    results = [''] * tuple[2][0]
    for i in range(len(indices)):
        index = indices[i][0]
        results[index] += alphabet.string_from_label(values[i])
def sparse_tuple_to_texts(sp_tuple, alphabet):
    indices = sp_tuple[0]
    values = sp_tuple[1]
    results = [''] * sp_tuple[2][0]
    for i, index in enumerate(indices):
        results[index[0]] += alphabet.string_from_label(values[i])
    # List of strings
    return results

@@ -63,7 +65,7 @@ def evaluate(test_csvs, create_model, try_loading):
                                          inputs=logits,
                                          sequence_length=batch_x_len)

    global_step = tf.train.get_or_create_global_step()
    tf.train.get_or_create_global_step()

    with tf.Session(config=Config.session_config) as session:
        # Create a saver using variables from the above newly created graph
@@ -109,7 +111,7 @@ def evaluate(test_csvs, create_model, try_loading):
        # Get number of accessible CPU cores for this process
        try:
            num_processes = cpu_count()
        except:
        except NotImplementedError:
            num_processes = 1

        print('Decoding predictions...')
@@ -151,12 +153,12 @@ def main(_):
                  'the --test_files flag.')
        exit(1)

    from DeepSpeech import create_model, try_loading
    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)

    if FLAGS.test_output_file:
        # Save decoded tuples as JSON, converting NumPy floats to Python floats
        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=lambda x: float(x))
        json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)


if __name__ == '__main__':
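The bare `except:` above is narrowed to `NotImplementedError`, which is what `multiprocessing.cpu_count()` can raise on platforms where the core count cannot be determined. A minimal sketch of the same fallback pattern, outside the evaluation script:

```python
# Minimal sketch of the narrowed exception handling used above.
from multiprocessing import cpu_count

try:
    num_processes = cpu_count()
except NotImplementedError:  # raised when the platform cannot report a core count
    num_processes = 1

print('Decoding with %d process(es)' % num_processes)
```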
@ -1,55 +1,56 @@
|
||||
import csv
|
||||
import sys
|
||||
import glob
|
||||
|
||||
"""
|
||||
Usage: $ python3 check_characters.py "INFILE"
|
||||
e.g. $ python3 check_characters.py -csv /home/data/french.csv
|
||||
e.g. $ python3 check_characters.py -csv ../train.csv,../test.csv
|
||||
e.g. $ python3 check_characters.py -alpha -csv ../train.csv
|
||||
e.g. $ python3 check_characters.py -csv ../train.csv,../test.csv
|
||||
e.g. $ python3 check_characters.py -alpha -csv ../train.csv
|
||||
|
||||
Point this script to your transcripts, and it returns
|
||||
to the terminal the unique set of characters in those
|
||||
Point this script to your transcripts, and it returns
|
||||
to the terminal the unique set of characters in those
|
||||
files (combined).
|
||||
|
||||
These files are assumed to be csv, with the transcript being the third field.
|
||||
|
||||
The script simply reads all the text from all the files,
|
||||
storing a set of unique characters that were seen
|
||||
The script simply reads all the text from all the files,
|
||||
storing a set of unique characters that were seen
|
||||
along the way.
|
||||
"""
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import sys
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
|
||||
parser.add_argument("-alpha", "--alphabet-format",help="Bool. Print in format for alphabet.txt",action="store_true")
|
||||
parser.set_defaults(alphabet_format=False)
|
||||
args = parser.parse_args()
|
||||
inFiles = [os.path.abspath(i) for i in args.csv_files.split(",")]
|
||||
parser.add_argument("-csv", "--csv-files", help="Str. Filenames as a comma separated list", required=True)
|
||||
parser.add_argument("-alpha", "--alphabet-format", help="Bool. Print in format for alphabet.txt", action="store_true")
|
||||
args = parser.parse_args()
|
||||
in_files = [os.path.abspath(i) for i in args.csv_files.split(",")]
|
||||
|
||||
print("### Reading in the following transcript files: ###")
|
||||
print("### {} ###".format(inFiles))
|
||||
print("### Reading in the following transcript files: ###")
|
||||
print("### {} ###".format(in_files))
|
||||
|
||||
allText = set()
|
||||
for inFile in (inFiles):
|
||||
with open(inFile, "r") as csvFile:
|
||||
reader = csv.reader(csvFile)
|
||||
try:
|
||||
next(reader, None) # skip the file header (i.e. "transcript")
|
||||
for row in reader:
|
||||
allText |= set(str(row[2]))
|
||||
except IndexError as ie:
|
||||
print("Your input file",inFile,"is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript")
|
||||
sys.exit(-1)
|
||||
finally:
|
||||
csvFile.close()
|
||||
all_text = set()
|
||||
for in_file in in_files:
|
||||
with open(in_file, "r") as csv_file:
|
||||
reader = csv.reader(csv_file)
|
||||
try:
|
||||
next(reader, None) # skip the file header (i.e. "transcript")
|
||||
for row in reader:
|
||||
all_text |= set(str(row[2]))
|
||||
except IndexError:
|
||||
print("Your input file", in_file, "is not formatted properly. Check if there are 3 columns with the 3rd containing the transcript")
|
||||
sys.exit(-1)
|
||||
finally:
|
||||
csv_file.close()
|
||||
|
||||
print("### The following unique characters were found in your transcripts: ###")
|
||||
if args.alphabet_format:
|
||||
for char in list(allText):
|
||||
print(char)
|
||||
print("### ^^^ You can copy-paste these into data/alphabet.txt ###")
|
||||
else:
|
||||
print(list(allText))
|
||||
print("### The following unique characters were found in your transcripts: ###")
|
||||
if args.alphabet_format:
|
||||
for char in list(all_text):
|
||||
print(char)
|
||||
print("### ^^^ You can copy-paste these into data/alphabet.txt ###")
|
||||
else:
|
||||
print(list(all_text))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
util/config.py
@@ -4,11 +4,12 @@ import os
import tensorflow as tf

from attrdict import AttrDict
from xdg import BaseDirectory as xdg

from util.flags import FLAGS
from util.gpu import get_available_gpus
from util.logging import log_error
from util.text import Alphabet
from xdg import BaseDirectory as xdg

class ConfigSingleton:
    _config = None
@@ -21,7 +22,7 @@ class ConfigSingleton:
        return ConfigSingleton._config[name]


Config = ConfigSingleton()
Config = ConfigSingleton() # pylint: disable=invalid-name

def initialize_globals():
    c = AttrDict()
@@ -33,7 +34,7 @@ def initialize_globals():
    c.available_devices = get_available_gpus()

    # If there is no GPU available, we fall back to CPU based operation
    if 0 == len(c.available_devices):
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    # Set default dropout rates
@@ -45,15 +46,15 @@ def initialize_globals():
        FLAGS.dropout_rate6 = FLAGS.dropout_rate

    # Set default checkpoint dir
    if len(FLAGS.checkpoint_dir) == 0:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech','checkpoints'))
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
        FLAGS.load = 'auto'

    # Set default summary dir
    if len(FLAGS.summary_dir) == 0:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech','summaries'))
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join('deepspeech', 'summaries'))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
@@ -103,4 +104,4 @@ def initialize_globals():
        log_error('Path specified in --one_shot_infer is not a valid file.')
        exit(1)

    ConfigSingleton._config = c
    ConfigSingleton._config = c # pylint: disable=protected-access
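Several of the changes above silence individual checks with inline pragmas instead of editing the shared .pylintrc. A minimal sketch of how these comments scope (the names are illustrative, not from the codebase):

```python
# Illustrative only: how inline pylint pragmas scope.

Config = object()  # pylint: disable=invalid-name
# A trailing pragma applies to this one statement: the module-level
# UPPER_CASE constant check is suppressed just for `Config`.

# pylint: disable=missing-docstring
# A pragma on its own line applies from here to the end of the
# enclosing block or module, until re-enabled.
def helper():
    return 42
# pylint: enable=missing-docstring
```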
util/evaluate_tools.py
@@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

from attrdict import AttrDict
from multiprocessing.dummy import Pool

from attrdict import AttrDict

from util.text import wer_cer_batch, levenshtein

def pmap(fun, iterable):
util/feeding.py
@@ -1,13 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import numpy as np
import os

from functools import partial

import numpy as np
import pandas
import tensorflow as tf

from functools import partial
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

from util.config import Config
from util.text import text_to_char_array

@@ -18,7 +21,7 @@ def read_csvs(csv_files):
        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
        #FIXME: not cross-platform
        csv_dir = os.path.dirname(os.path.abspath(csv))
        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1))) # pylint: disable=cell-var-from-loop
        if source_data is None:
            source_data = file
        else:
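The `# pylint: disable=cell-var-from-loop` pragma above suppresses pylint's warning that the lambda closes over `csv_dir`, a variable that is rebound on every loop iteration. Here the lambda is applied immediately within the same iteration, so the warning is a false positive and is silenced rather than fixed. A minimal sketch of the late-binding behaviour the check is really about, with illustrative names only:

```python
# Illustrative only: closures capture the variable, not its value at creation time.
makers = []
for n in [1, 2, 3]:
    makers.append(lambda x: x + n)        # what cell-var-from-loop warns about
print([f(0) for f in makers])             # [3, 3, 3] - every lambda sees the final n

binders = []
for n in [1, 2, 3]:
    binders.append(lambda x, n=n: x + n)  # default argument binds n per iteration
print([f(0) for f in binders])            # [1, 2, 3]
```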
util/flags.py (126 lines changed)
@@ -10,110 +10,110 @@ def create_flags():
|
||||
# Importer
|
||||
# ========
|
||||
|
||||
tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
|
||||
tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
|
||||
tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
|
||||
tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
|
||||
f = tf.app.flags
|
||||
|
||||
tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
|
||||
tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
|
||||
tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
|
||||
f.DEFINE_string('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
|
||||
f.DEFINE_string('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
|
||||
f.DEFINE_string('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
|
||||
|
||||
tf.app.flags.DEFINE_integer ('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
|
||||
tf.app.flags.DEFINE_integer ('feature_win_step', 20, 'feature extraction window step length in milliseconds')
|
||||
tf.app.flags.DEFINE_integer ('audio_sample_rate',16000, 'sample rate value expected by model')
|
||||
f.DEFINE_string('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
|
||||
f.DEFINE_string('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
|
||||
f.DEFINE_string('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
|
||||
|
||||
f.DEFINE_integer('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
|
||||
f.DEFINE_integer('feature_win_step', 20, 'feature extraction window step length in milliseconds')
|
||||
f.DEFINE_integer('audio_sample_rate', 16000, 'sample rate value expected by model')
|
||||
|
||||
# Global Constants
|
||||
# ================
|
||||
|
||||
tf.app.flags.DEFINE_integer ('epochs', 75, 'how many epochs (complete runs through the train files) to train for')
|
||||
f.DEFINE_integer('epochs', 75, 'how many epochs (complete runs through the train files) to train for')
|
||||
|
||||
tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
|
||||
tf.app.flags.DEFINE_float ('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
|
||||
tf.app.flags.DEFINE_float ('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
|
||||
tf.app.flags.DEFINE_float ('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
|
||||
tf.app.flags.DEFINE_float ('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
|
||||
tf.app.flags.DEFINE_float ('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')
|
||||
f.DEFINE_float('dropout_rate', 0.05, 'dropout rate for feedforward layers')
|
||||
f.DEFINE_float('dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate')
|
||||
f.DEFINE_float('dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate')
|
||||
f.DEFINE_float('dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0')
|
||||
f.DEFINE_float('dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0')
|
||||
f.DEFINE_float('dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate')
|
||||
|
||||
tf.app.flags.DEFINE_float ('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')
|
||||
f.DEFINE_float('relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers')
|
||||
|
||||
# Adam optimizer (http://arxiv.org/abs/1412.6980) parameters
|
||||
# Adam optimizer(http://arxiv.org/abs/1412.6980) parameters
|
||||
|
||||
tf.app.flags.DEFINE_float ('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
|
||||
tf.app.flags.DEFINE_float ('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
|
||||
tf.app.flags.DEFINE_float ('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
|
||||
tf.app.flags.DEFINE_float ('learning_rate', 0.001, 'learning rate of Adam optimizer')
|
||||
f.DEFINE_float('beta1', 0.9, 'beta 1 parameter of Adam optimizer')
|
||||
f.DEFINE_float('beta2', 0.999, 'beta 2 parameter of Adam optimizer')
|
||||
f.DEFINE_float('epsilon', 1e-8, 'epsilon parameter of Adam optimizer')
|
||||
f.DEFINE_float('learning_rate', 0.001, 'learning rate of Adam optimizer')
|
||||
|
||||
# Batch sizes
|
||||
|
||||
tf.app.flags.DEFINE_integer ('train_batch_size', 1, 'number of elements in a training batch')
|
||||
tf.app.flags.DEFINE_integer ('dev_batch_size', 1, 'number of elements in a validation batch')
|
||||
tf.app.flags.DEFINE_integer ('test_batch_size', 1, 'number of elements in a test batch')
|
||||
f.DEFINE_integer('train_batch_size', 1, 'number of elements in a training batch')
|
||||
f.DEFINE_integer('dev_batch_size', 1, 'number of elements in a validation batch')
|
||||
f.DEFINE_integer('test_batch_size', 1, 'number of elements in a test batch')
|
||||
|
||||
tf.app.flags.DEFINE_integer ('export_batch_size', 1, 'number of elements per batch on the exported graph')
|
||||
f.DEFINE_integer('export_batch_size', 1, 'number of elements per batch on the exported graph')
|
||||
|
||||
# Performance (UNSUPPORTED)
|
||||
tf.app.flags.DEFINE_integer ('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
|
||||
tf.app.flags.DEFINE_integer ('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
|
||||
# Performance(UNSUPPORTED)
|
||||
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details')
|
||||
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details')
|
||||
|
||||
# Sample limits
|
||||
|
||||
tf.app.flags.DEFINE_integer ('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
|
||||
tf.app.flags.DEFINE_integer ('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
|
||||
tf.app.flags.DEFINE_integer ('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
|
||||
f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
|
||||
f.DEFINE_integer('limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit')
|
||||
f.DEFINE_integer('limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit')
|
||||
|
||||
# Checkpointing
|
||||
|
||||
tf.app.flags.DEFINE_string ('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
|
||||
tf.app.flags.DEFINE_integer ('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
|
||||
tf.app.flags.DEFINE_integer ('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
|
||||
tf.app.flags.DEFINE_string ('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')
|
||||
f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
|
||||
f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
|
||||
f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
|
||||
f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')

# Exporting

tf.app.flags.DEFINE_string ('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
tf.app.flags.DEFINE_integer ('export_version', 1, 'version number of the exported model')
tf.app.flags.DEFINE_boolean ('remove_export', False, 'whether to remove old exported models')
tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
tf.app.flags.DEFINE_string ('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
f.DEFINE_integer('export_version', 1, 'version number of the exported model')
f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
f.DEFINE_boolean('use_seq_length', True, 'have sequence_length in the exported graph(will make tfcompile unhappy)')
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')

# Reporting

tf.app.flags.DEFINE_integer ('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
tf.app.flags.DEFINE_boolean ('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')

tf.app.flags.DEFINE_boolean ('log_placement', False, 'whether to log device placement of the operators to the console')
tf.app.flags.DEFINE_integer ('report_count', 10, 'number of phrases with lowest WER (best matching) to print out during a WER report')
f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
f.DEFINE_integer('report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report')

tf.app.flags.DEFINE_string ('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_string('summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification')

# Geometry

tf.app.flags.DEFINE_integer ('n_hidden', 2048, 'layer width to use when initialising layers')
f.DEFINE_integer('n_hidden', 2048, 'layer width to use when initialising layers')

# Initialization

tf.app.flags.DEFINE_integer ('random_seed', 4568, 'default random seed that is used to initialize variables')
f.DEFINE_integer('random_seed', 4568, 'default random seed that is used to initialize variables')

# Early Stopping

tf.app.flags.DEFINE_boolean ('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
tf.app.flags.DEFINE_integer ('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
tf.app.flags.DEFINE_float ('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
tf.app.flags.DEFINE_float ('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
f.DEFINE_boolean('early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.')
f.DEFINE_integer('es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point')
f.DEFINE_float('es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
f.DEFINE_float('es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')

# Decoder

tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
f.DEFINE_string('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
f.DEFINE_string('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
f.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
f.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')

# Inference mode

tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')

f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
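Note: a minimal sketch of the pattern behind the shortened calls above. It assumes `f` is an alias for the flags module (for example `tf.app.flags`) and `FLAGS` holds the parsed values; neither name's definition is shown in this diff, so both are assumptions.

import tensorflow as tf

f = tf.app.flags            # alias keeps each DEFINE_* call on one line (assumed)
FLAGS = tf.app.flags.FLAGS  # parsed flag values (assumed)

f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder')

def main(_):
    # tf.app.run() parses the command line before calling main.
    print('beam width:', FLAGS.beam_width)

if __name__ == '__main__':
    tf.app.run(main)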
@@ -27,4 +27,4 @@ def log_warn(message):

def log_error(message):
    if FLAGS.log_level <= 3:
        prefix_print('E ', message)
        prefix_print('E ', message)
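Note: an illustrative sketch of how a log_level-gated helper like the one above typically works. prefix_print's real body is not shown in this diff, so the version here is an assumption, and LOG_LEVEL stands in for FLAGS.log_level.

LOG_LEVEL = 1  # 0: INFO, 1: WARN, 2: ERROR, 3: FATAL

def prefix_print(prefix, message):
    # Assumed helper: prepend the prefix to every line of a multi-line message.
    print(prefix + ('\n' + prefix).join(message.split('\n')))

def log_error(message):
    if LOG_LEVEL <= 3:  # errors and fatals are always shown
        prefix_print('E ', message)

log_error('example error\nwith a second line')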
@@ -2,12 +2,14 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import, division

import argparse
import platform
import subprocess
import sys
import os
import errno
import stat

import six.moves.urllib as urllib

from pkg_resources import parse_version
@@ -23,9 +25,9 @@ TASKCLUSTER_SCHEME = os.getenv('TASKCLUSTER_SCHEME', DEFAULT_SCHEMES['deepspeech
def get_tc_url(arch_string, artifact_name='native_client.tar.xz', branch_name='master'):
    assert arch_string is not None
    assert artifact_name is not None
    assert len(artifact_name) > 0
    assert artifact_name
    assert branch_name is not None
    assert len(branch_name) > 0
    assert branch_name

    return TASKCLUSTER_SCHEME % { 'arch_string': arch_string, 'artifact_name': artifact_name, 'branch_name': branch_name}
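Note: the assert rewrites above follow pylint's len-as-condition check (C1801), which prefers relying on the fact that an empty sequence is already falsy. A small self-contained illustration with made-up argument values:

def check_args(arch_string, artifact_name, branch_name):
    # Truthiness covers both None and the empty string in one check.
    assert arch_string
    assert artifact_name
    assert branch_name

check_args('linux-amd64', 'native_client.tar.xz', 'master')   # passes
# check_args('linux-amd64', '', 'master')                     # would raise AssertionError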
@@ -66,9 +68,7 @@ def maybe_download_tc_bin(**kwargs):
def read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()

if __name__ == '__main__':
    import argparse

def main():
    parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
    parser.add_argument('--target', required=False,
                        help='Where to put the native client binary files')
@@ -151,3 +151,6 @@ if __name__ == '__main__':

    if '.tar.' in args.artifact:
        subprocess.check_call(['tar', 'xvf', os.path.join(args.target, args.artifact), '-C', args.target])

if __name__ == '__main__':
    main()
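Note: a minimal sketch of the entry-point refactor shown above, where module-level script code moves into a main() function so the module can be imported without side effects. The parser description and --target argument are copied from the hunk; everything else is illustrative.

import argparse

def main():
    parser = argparse.ArgumentParser(description='Tooling to ease downloading of components from TaskCluster.')
    parser.add_argument('--target', required=False,
                        help='Where to put the native client binary files')
    args = parser.parse_args()
    if args.target:
        print('will download into', args.target)

if __name__ == '__main__':
    main()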
25
util/text.py
@@ -1,9 +1,9 @@
from __future__ import absolute_import, division, print_function

import codecs
import numpy as np
import re
import sys

import numpy as np

from six.moves import range

@@ -33,7 +33,6 @@ class Alphabet(object):
            raise KeyError(
                '''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
            ).with_traceback(e.__traceback__)
            sys.exit()

    def decode(self, labels):
        res = ''
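Note: a sketch of the re-raise pattern in the hunk above — a lookup failure is turned into a friendlier KeyError while the original traceback is preserved via with_traceback (Python 3). The mapping, function name, and message here are illustrative, not the Alphabet class itself.

def label_from_string(mapping, string):
    try:
        return mapping[string]
    except KeyError as e:
        raise KeyError(
            'character %r does not occur in the alphabet' % string
        ).with_traceback(e.__traceback__)

print(label_from_string({'a': 0}, 'a'))   # 0
# label_from_string({'a': 0}, '9')        # raises the friendlier KeyError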
@@ -94,18 +93,18 @@ def wer_cer_batch(originals, results):
# version 1.0. This software is distributed without any warranty. For more
# information, see <http://creativecommons.org/publicdomain/zero/1.0>

def levenshtein(a,b):
def levenshtein(a, b):
    "Calculates the Levenshtein distance between a and b."
    n, m = len(a), len(b)
    if n > m:
        # Make sure n <= m, to use O(min(n,m)) space
        a,b = b,a
        n,m = m,n
        a, b = b, a
        n, m = m, n

    current = list(range(n+1))
    for i in range(1,m+1):
    for i in range(1, m+1):
        previous, current = current, [i]+[0]*n
        for j in range(1,n+1):
        for j in range(1, n+1):
            add, delete = previous[j]+1, current[j-1]+1
            change = previous[j-1]
            if a[j-1] != b[i-1]:
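Note: the hunk above cuts the function off mid-loop. For reference, a self-contained version of the same two-row dynamic programme, spaced as on the new side of the diff; the final mismatch increment, min update, and return are the standard completion and are not shown in the diff itself.

def levenshtein(a, b):
    "Calculates the Levenshtein distance between a and b."
    n, m = len(a), len(b)
    if n > m:
        # Make sure n <= m, to use O(min(n,m)) space
        a, b = b, a
        n, m = m, n

    current = list(range(n+1))
    for i in range(1, m+1):
        previous, current = current, [i]+[0]*n
        for j in range(1, n+1):
            add, delete = previous[j]+1, current[j-1]+1
            change = previous[j-1]
            if a[j-1] != b[i-1]:
                change += 1
            current[j] = min(add, delete, change)

    return current[n]

print(levenshtein('kitten', 'sitting'))   # 3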
@@ -118,14 +117,7 @@ def levenshtein(a,b):
# or None if it's invalid.
def validate_label(label):
    # For now we can only handle [a-z ']
    if "(" in label or \
        "<" in label or \
        "[" in label or \
        "]" in label or \
        "&" in label or \
        "*" in label or \
        "{" in label or \
        re.search(r"[0-9]", label) != None:
    if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
        return None

    label = label.replace("-", "")
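Note: a quick check that the single regex on the new side rejects the same labels as the old chain of `or` tests; the sample strings below are made up.

import re

BAD_CHARACTERS = re.compile(r"[0-9]|[(<\[\]&*{]")

for text in ["hello world", "take 2", "left (unclear)", "a & b", "{noise}"]:
    print(text, "->", "rejected" if BAD_CHARACTERS.search(text) else "kept")
# Only "hello world" is kept; the rest contain a digit or one of ( < [ ] & * {.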
@@ -138,4 +130,3 @@ def validate_label(label):
    label = label.lower()

    return label if label else None