Merge pull request #1908 from JRMeyer/config-logic

Commit: c19faeb5d0

@@ -14,7 +14,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --scorer "" \
   --augment dropout \

@@ -14,7 +14,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -20,7 +20,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} --train_batch_size 1 \
   --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
   --test_files ${ldc93s1_sdb} --test_batch_size 1 \

@@ -17,7 +17,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache' \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \

@@ -17,7 +17,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -23,7 +23,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} --train_batch_size 1 \
   --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
   --test_files ${ldc93s1_sdb} --test_batch_size 1 \

@@ -23,7 +23,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} ${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
   --dev_files ${ldc93s1_sdb} ${ldc93s1_csv} --dev_batch_size 1 \

@@ -14,7 +14,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false --early_stop false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
@@ -23,7 +24,7 @@ python -u train.py --show_progressbar false --early_stop false \
   --learning_rate 0.001 --dropout_rate 0.05 \
   --scorer_path 'data/smoke_test/pruned_lm.scorer'
 
-python -u train.py \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --scorer_path 'data/smoke_test/pruned_lm.scorer' \

@@ -16,7 +16,8 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --export_dir '/tmp/train_tflite' \
@@ -26,7 +27,8 @@ python -u train.py --show_progressbar false \
 
 mkdir /tmp/train_tflite/en-us
 
-python -u train.py --show_progressbar false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --export_dir '/tmp/train_tflite/en-us' \

bin/run-ldc93s1.py (new executable file, 28 lines)

@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+import os
+from import_ldc93s1 import _download_and_preprocess_data as download_ldc
+from coqui_stt_training.util.config import initialize_globals_from_args
+from coqui_stt_training.train import train, test, early_training_checks
+import tensorflow.compat.v1 as tfv1
+
+# only one GPU for only one training sample
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+download_ldc("data/ldc93s1")
+
+initialize_globals_from_args(
+    load_train="init",
+    alphabet_config_path="data/alphabet.txt",
+    train_files=["data/ldc93s1/ldc93s1.csv"],
+    dev_files=["data/ldc93s1/ldc93s1.csv"],
+    test_files=["data/ldc93s1/ldc93s1.csv"],
+    augment=["time_mask"],
+    n_hidden=100,
+    epochs=200,
+)
+
+early_training_checks()
+
+train()
+tfv1.reset_default_graph()
+test()
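
Note: bin/run-ldc93s1.py doubles as a usage example for the new
initialize_globals_from_args() API. A minimal sketch of the same pattern
against another dataset (the my_data/*.csv paths are hypothetical
placeholders; everything else is taken from the file above):

    from coqui_stt_training.util.config import initialize_globals_from_args
    from coqui_stt_training.train import train

    # Keyword overrides stand in for CLI flags; any field not passed here
    # keeps its _SttConfig default.
    initialize_globals_from_args(
        alphabet_config_path="data/alphabet.txt",
        train_files=["my_data/train.csv"],  # hypothetical path
        dev_files=["my_data/dev.csv"],      # hypothetical path
        n_hidden=100,
        epochs=10,
    )
    train()
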
@@ -20,7 +20,8 @@ fi
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --show_progressbar false \
+python -u train.py --alphabet_config_path "data/alphabet.txt" \
+  --show_progressbar false \
   --train_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \
   --train_batch_size 1 \

@@ -66,3 +66,6 @@ time ./bin/run-ci-ldc93s1_checkpoint_sdb.sh
 
 # Bytes output mode, resuming from checkpoint
 time ./bin/run-ci-ldc93s1_checkpoint_bytes.sh
+
+# Training with args set via initialize_globals_from_args()
+time python ./bin/run-ldc93s1.py

@@ -10,7 +10,7 @@ import tensorflow.compat.v1 as tfv1
 from coqui_stt_ctcdecoder import Scorer
 from coqui_stt_training.evaluate import evaluate
 from coqui_stt_training.train import create_model
-from coqui_stt_training.util.config import Config, initialize_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.evaluate_tools import wer_cer_batch
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import log_error
@@ -52,7 +52,7 @@ def objective(trial):
 
 
 def main(_):
-    initialize_globals()
+    initialize_globals_from_cli()
 
     if not FLAGS.test_files:
         log_error(

@@ -18,7 +18,7 @@ from .util.checkpoints import load_graph_for_evaluation
 from .util.config import (
     Config,
     create_progressbar,
-    initialize_globals,
+    initialize_globals_from_cli,
     log_error,
     log_progress,
 )
@@ -169,7 +169,7 @@ def evaluate(test_csvs, create_model):
 
 
 def main():
-    initialize_globals()
+    initialize_globals_from_cli()
 
     if not Config.test_files:
         log_error(

@@ -45,7 +45,7 @@ from .util.checkpoints import (
 from .util.config import (
     Config,
     create_progressbar,
-    initialize_globals,
+    initialize_globals_from_cli,
     log_debug,
     log_error,
     log_info,
@@ -1249,7 +1249,7 @@ def early_training_checks():
 
 
 def main():
-    initialize_globals()
+    initialize_globals_from_cli()
     early_training_checks()
 
     if Config.train_files:

@@ -36,10 +36,201 @@ Config = _ConfigSingleton() # pylint: disable=invalid-name
 
 @dataclass
 class _SttConfig(Coqpit):
+    def __post_init__(self):
+        # Augmentations
+        self.augmentations = parse_augmentations(self.augment)
+        if self.augmentations:
+            print(f"Parsed augmentations: {self.augmentations}")
+        if self.augmentations and self.feature_cache and self.cache_for_epochs == 0:
+            print(
+                "Due to your feature-cache settings, augmentations of "
+                "the first epoch will be repeated on all following epochs. "
+                "This may lead to unintended over-fitting. "
+                "You can use --cache_for_epochs <n_epochs> to invalidate "
+                "the cache after a given number of epochs."
+            )
+
+        if self.normalize_sample_rate:
+            self.augmentations = [NormalizeSampleRate(self.audio_sample_rate)] + self[
+                "augmentations"
+            ]
+
+        # Caching
+        if self.cache_for_epochs == 1:
+            print(
+                "--cache_for_epochs == 1 is (re-)creating the feature cache "
+                "on every epoch but will never use it. You can either set "
+                "--cache_for_epochs > 1, or not use feature caching at all."
+            )
+
+        # Read-buffer
+        self.read_buffer = parse_file_size(self.read_buffer)
+
+        # Set default dropout rates
+        if self.dropout_rate2 < 0:
+            self.dropout_rate2 = self.dropout_rate
+        if self.dropout_rate3 < 0:
+            self.dropout_rate3 = self.dropout_rate
+        if self.dropout_rate6 < 0:
+            self.dropout_rate6 = self.dropout_rate
+
+        # Set default checkpoint dir
+        # If separate save and load flags were not specified, default to load and save
+        # from the same dir.
+
+        # if save_checkpoint_dir hasn't been set, or checkpoint_dir is new
+        if (not self.save_checkpoint_dir) or (
+            self.save_checkpoint_dir is not self.checkpoint_dir
+        ):
+            if not self.checkpoint_dir:
+                self.checkpoint_dir = xdg.save_data_path(
+                    os.path.join("stt", "checkpoints")
+                )
+                self.save_checkpoint_dir = self.checkpoint_dir
+            else:
+                self.save_checkpoint_dir = self.checkpoint_dir
+        # if load_checkpoint_dir hasn't been set, or checkpoint_dir is new
+        if (not self.load_checkpoint_dir) or (
+            self.load_checkpoint_dir is not self.checkpoint_dir
+        ):
+            if not self.checkpoint_dir:
+                self.checkpoint_dir = xdg.load_data_path(
+                    os.path.join("stt", "checkpoints")
+                )
+                self.load_checkpoint_dir = self.checkpoint_dir
+            else:
+                self.load_checkpoint_dir = self.checkpoint_dir
+
+        if self.load_train not in ["last", "best", "init", "auto"]:
+            self.load_train = "auto"
+
+        if self.load_evaluate not in ["last", "best", "auto"]:
+            self.load_evaluate = "auto"
+
+        # Set default summary dir
+        if not self.summary_dir:
+            self.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
+
+        # Standard session configuration that'll be used for all new sessions.
+        self.session_config = tfv1.ConfigProto(
+            allow_soft_placement=True,
+            log_device_placement=self.log_placement,
+            inter_op_parallelism_threads=self.inter_op_parallelism_threads,
+            intra_op_parallelism_threads=self.intra_op_parallelism_threads,
+            gpu_options=tfv1.GPUOptions(allow_growth=self.use_allow_growth),
+        )
+
+        # CPU device
+        self.cpu_device = "/cpu:0"
+
+        # Available GPU devices
+        self.available_devices = get_available_gpus(self.session_config)
+
+        # If there is no GPU available, we fall back to CPU based operation
+        if not self.available_devices:
+            self.available_devices = [self.cpu_device]
+
+        if self.bytes_output_mode and self.alphabet_config_path:
+            raise RuntimeError(
+                "You cannot set --alphabet_config_path *and* --bytes_output_mode"
+            )
+        elif self.bytes_output_mode:
+            self.alphabet = UTF8Alphabet()
+        elif self.alphabet_config_path:
+            self.alphabet = Alphabet(os.path.abspath(self.alphabet_config_path))
+
+        # Geometric Constants
+        # ===================
+
+        # For an explanation of the meaning of the geometric constants
+        # please refer to doc/Geometry.md
+
+        # Number of MFCC features
+        self.n_input = 26  # TODO: Determine this programmatically from the sample rate
+
+        # The number of frames in the context
+        self.n_context = (
+            9  # TODO: Determine the optimal value using a validation data set
+        )
+
+        # Number of units in hidden layers
+        self.n_hidden = self.n_hidden
+
+        self.n_hidden_1 = self.n_hidden
+
+        self.n_hidden_2 = self.n_hidden
+
+        self.n_hidden_5 = self.n_hidden
+
+        # LSTM cell state dimension
+        self.n_cell_dim = self.n_hidden
+
+        # The number of units in the third layer, which feeds in to the LSTM
+        self.n_hidden_3 = self.n_cell_dim
+
+        # Dims in last layer = number of characters in alphabet plus one
+        try:
+            # +1 for CTC blank label
+            self.n_hidden_6 = self.alphabet.GetSize() + 1
+        except AttributeError:
+            # no alphabet was configured above; leave n_hidden_6 unset
+            pass
+
+        # Size of audio window in samples
+        if (self.feature_win_len * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_len value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_len value or resample your audio accordingly."
+                "".format(
+                    self.feature_win_len,
+                    self.feature_win_len / 1000,
+                    self.audio_sample_rate,
+                )
+            )
+            sys.exit(1)
+
+        self.audio_window_samples = self.audio_sample_rate * (
+            self.feature_win_len / 1000
+        )
+
+        # Stride for feature computations in samples
+        if (self.feature_win_step * self.audio_sample_rate) % 1000 != 0:
+            log_error(
+                "--feature_win_step value ({}) in milliseconds ({}) multiplied "
+                "by --audio_sample_rate value ({}) must be an integer value. Adjust "
+                "your --feature_win_step value or resample your audio accordingly."
+                "".format(
+                    self.feature_win_step,
+                    self.feature_win_step / 1000,
+                    self.audio_sample_rate,
+                )
+            )
+            sys.exit(1)
+
+        self.audio_step_samples = self.audio_sample_rate * (
+            self.feature_win_step / 1000
+        )
+
+        if self.one_shot_infer:
+            if not path_exists_remote(self.one_shot_infer):
+                log_error("Path specified in --one_shot_infer is not a valid file.")
+                sys.exit(1)
+
+        if self.train_cudnn and self.load_cudnn:
+            log_error(
+                "Trying to use --train_cudnn, but --load_cudnn "
+                "was also specified. The --load_cudnn flag is only "
+                "needed when converting a CuDNN RNN checkpoint to "
+                "a CPU-capable graph. If your system is capable of "
+                "using CuDNN RNN, you can just specify the CuDNN RNN "
+                "checkpoint normally with --save_checkpoint_dir."
+            )
+            sys.exit(1)
+
     train_files: List[str] = field(
         default_factory=list,
         metadata=dict(
-            help="space-separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run."
+            help="space-separated list of files specifying the datasets used for training. Multiple files will get merged. If empty, training will not be run."
         ),
     )
     dev_files: List[str] = field(
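
Note: a sketch of the alphabet selection introduced in __post_init__ above
(assuming Alphabet and UTF8Alphabet are the ctcdecoder classes imported
elsewhere in this module):

    from coqui_stt_training.util.config import initialize_globals_from_args

    initialize_globals_from_args(bytes_output_mode=True)  # -> UTF8Alphabet()
    initialize_globals_from_args(
        alphabet_config_path="data/alphabet.txt"
    )  # -> file-based Alphabet
    initialize_globals_from_args(
        bytes_output_mode=True,
        alphabet_config_path="data/alphabet.txt",
    )  # -> RuntimeError: the two options are mutually exclusive
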
@@ -472,7 +663,7 @@ class _SttConfig(Coqpit):
         ),
     )
     alphabet_config_path: str = field(
-        default="data/alphabet.txt",
+        default="",
         metadata=dict(
             help="path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format."
         ),
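
Note: this default change is what forces the scripts above to pass
--alphabet_config_path explicitly. A sketch of the fallthrough when neither
alphabet option is set (per the __post_init__ logic above):

    cfg = _SttConfig()  # alphabet_config_path == "" and bytes_output_mode off
    # None of the three alphabet branches fires, so cfg.alphabet is never
    # assigned; the try/except AttributeError around the n_hidden_6
    # computation guards exactly this case.
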
@@ -540,166 +731,17 @@ class _SttConfig(Coqpit):
         ),
     )
 
-    def check_values(self):
-        c = asdict(self)
-        check_argument("alphabet_config_path", c, is_path=True)
-        check_argument("one_shot_infer", c, is_path=True)
 
-
-def initialize_globals():
+
+def initialize_globals_from_cli():
     c = _SttConfig()
     c.parse_args(arg_prefix="")
+    c.__post_init__()
+    _ConfigSingleton._config = c  # pylint: disable=protected-access
 
-    # Augmentations
-    c.augmentations = parse_augmentations(c.augment)
-    print(f"Parsed augmentations from flags: {c.augmentations}")
-    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
-        print(
-            "Due to current feature-cache settings the exact same sample augmentations of the first "
-            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
-            "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
-        )
-
-    if c.normalize_sample_rate:
-        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c[
-            "augmentations"
-        ]
-
-    # Caching
-    if c.cache_for_epochs == 1:
-        print(
-            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
-        )
-
-    # Read-buffer
-    c.read_buffer = parse_file_size(c.read_buffer)
-
-    # Set default dropout rates
-    if c.dropout_rate2 < 0:
-        c.dropout_rate2 = c.dropout_rate
-    if c.dropout_rate3 < 0:
-        c.dropout_rate3 = c.dropout_rate
-    if c.dropout_rate6 < 0:
-        c.dropout_rate6 = c.dropout_rate
-
-    # Set default checkpoint dir
-    if not c.checkpoint_dir:
-        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
-
-    if c.load_train not in ["last", "best", "init", "auto"]:
-        c.load_train = "auto"
-
-    if c.load_evaluate not in ["last", "best", "auto"]:
-        c.load_evaluate = "auto"
-
-    # Set default summary dir
-    if not c.summary_dir:
-        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
-
-    # Standard session configuration that'll be used for all new sessions.
-    c.session_config = tfv1.ConfigProto(
-        allow_soft_placement=True,
-        log_device_placement=c.log_placement,
-        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
-        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
-        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
-    )
-
-    # CPU device
-    c.cpu_device = "/cpu:0"
-
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
-
-    # If there is no GPU available, we fall back to CPU based operation
-    if not c.available_devices:
-        c.available_devices = [c.cpu_device]
-
-    if c.bytes_output_mode:
-        c.alphabet = UTF8Alphabet()
-    else:
-        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))
-
-    # Geometric Constants
-    # ===================
-
-    # For an explanation of the meaning of the geometric constants, please refer to
-    # doc/Geometry.md
-
-    # Number of MFCC features
-    c.n_input = 26  # TODO: Determine this programmatically from the sample rate
-
-    # The number of frames in the context
-    c.n_context = 9  # TODO: Determine the optimal value using a validation data set
-
-    # Number of units in hidden layers
-    c.n_hidden = c.n_hidden
-
-    c.n_hidden_1 = c.n_hidden
-
-    c.n_hidden_2 = c.n_hidden
-
-    c.n_hidden_5 = c.n_hidden
-
-    # LSTM cell state dimension
-    c.n_cell_dim = c.n_hidden
-
-    # The number of units in the third layer, which feeds in to the LSTM
-    c.n_hidden_3 = c.n_cell_dim
-
-    # Units in the sixth layer = number of characters in the target language plus one
-    c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label
-
-    # Size of audio window in samples
-    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_len value or resample your audio accordingly."
-            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
-        )
-        sys.exit(1)
-
-    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)
-
-    # Stride for feature computations in samples
-    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
-        log_error(
-            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
-            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
-            "your --feature_win_step value or resample your audio accordingly."
-            "".format(
-                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
-            )
-        )
-        sys.exit(1)
-
-    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)
-
-    if c.one_shot_infer:
-        if not path_exists_remote(c.one_shot_infer):
-            log_error("Path specified in --one_shot_infer is not a valid file.")
-            sys.exit(1)
-
-    if c.train_cudnn and c.load_cudnn:
-        log_error(
-            "Trying to use --train_cudnn, but --load_cudnn "
-            "was also specified. The --load_cudnn flag is only "
-            "needed when converting a CuDNN RNN checkpoint to "
-            "a CPU-capable graph. If your system is capable of "
-            "using CuDNN RNN, you can just specify the CuDNN RNN "
-            "checkpoint normally with --save_checkpoint_dir."
-        )
-        sys.exit(1)
-
-    # If separate save and load flags were not specified, default to load and save
-    # from the same dir.
-    if not c.save_checkpoint_dir:
-        c.save_checkpoint_dir = c.checkpoint_dir
-
-    if not c.load_checkpoint_dir:
-        c.load_checkpoint_dir = c.checkpoint_dir
+
+def initialize_globals_from_args(**override_args):
+    # Update Config with new args
+    c = _SttConfig(**override_args)
     _ConfigSingleton._config = c  # pylint: disable=protected-access
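
Note: the refactor leaves two entry points sharing one validation path. A
sketch of the division of labor (data/train.csv is a placeholder; the
explicit __post_init__() call mirrors the code above, since parsing flags
alone does not re-run it):

    # CLI path: flags come from sys.argv, then derived state (session
    # config, alphabet, geometry) is computed by __post_init__().
    initialize_globals_from_cli()

    # Programmatic path: constructing the dataclass runs __post_init__
    # automatically, so both paths end up validated the same way.
    initialize_globals_from_args(train_files=["data/train.csv"], n_hidden=100)
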
@@ -19,14 +19,19 @@ ValueRange = namedtuple("ValueRange", "start end r")
 
 
 def parse_file_size(file_size):
-    file_size = file_size.lower().strip()
-    if len(file_size) == 0:
-        return 0
-    n = int(keep_only_digits(file_size))
-    if file_size[-1] == "b":
-        file_size = file_size[:-1]
-    e = file_size[-1]
-    return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
+    if type(file_size) is str:
+        file_size = file_size.lower().strip()
+        if len(file_size) == 0:
+            return 0
+        n = int(keep_only_digits(file_size))
+        if file_size[-1] == "b":
+            file_size = file_size[:-1]
+        e = file_size[-1]
+        return SIZE_PREFIX_LOOKUP[e] * n if e in SIZE_PREFIX_LOOKUP else n
+    elif type(file_size) is int:
+        return file_size
+    else:
+        raise ValueError("file_size not of type 'int' or 'str'")
 
 
 def keep_only_digits(txt):
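
Note: a quick sketch of the widened parse_file_size() contract (assuming the
usual SIZE_PREFIX_LOOKUP mapping of "k"/"m"/"g" to powers of 1024, which is
not shown in this diff):

    parse_file_size("1gb")  # 1 * 1024**3, via the "g" prefix
    parse_file_size("512")  # 512: bare digits, no recognized size prefix
    parse_file_size(65536)  # 65536: ints now pass through unchanged
    parse_file_size(1.5)    # raises ValueError("file_size not of type 'int' or 'str'")
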
@@ -20,7 +20,7 @@ from multiprocessing import Process, cpu_count
 
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from coqui_stt_training.util.audio import AudioFile
-from coqui_stt_training.util.config import Config, initialize_globals
+from coqui_stt_training.util.config import Config, initialize_globals_from_cli
 from coqui_stt_training.util.feeding import split_audio_file
 from coqui_stt_training.util.flags import FLAGS, create_flags
 from coqui_stt_training.util.logging import (
@@ -42,7 +42,8 @@ def transcribe_file(audio_path, tlog_path):
     )
     from coqui_stt_training.util.checkpoints import load_graph_for_evaluation
 
-    initialize_globals()
+    initialize_globals_from_cli()
+
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
     try:
         num_processes = cpu_count()