Switch flag/config handling to Coqpit
parent fb826f714d
commit 5ad6e6abbf
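The change replaces absl's FLAGS with a typed Coqpit dataclass exposed through a Config singleton: boolean negation flags such as --noshow_progressbar become explicit values (--show_progressbar false), FLAGS.foo reads become Config.foo, and comma-split file lists become real List[str] fields. As a minimal sketch of the pattern (not taken verbatim from this commit; it uses only the Coqpit APIs visible in the diff below — dataclass-style fields with help metadata, attribute access, and serialize() — and the DemoConfig name is illustrative):

    from dataclasses import dataclass, field

    from coqpit import Coqpit

    @dataclass
    class DemoConfig(Coqpit):
        # Each absl flag becomes a typed field; the old flag help string
        # moves into the field metadata.
        epochs: int = field(default=75, metadata=dict(help="how many epochs to train for"))
        show_progressbar: bool = field(
            default=True, metadata=dict(help="show progress during training")
        )

    config = DemoConfig()
    print(config.epochs)       # attribute access: FLAGS.epochs becomes Config.epochs
    print(config.serialize())  # JSON-friendly dump, used for flags.txt in train.py below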
@@ -14,16 +14,16 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --scorer "" \
   --augment dropout \
-  --augment pitch \
-  --augment tempo \
-  --augment warp \
-  --augment time_mask \
-  --augment frequency_mask \
-  --augment add \
-  --augment multiply \
+  pitch \
+  tempo \
+  warp \
+  time_mask \
+  frequency_mask \
+  add \
+  multiply \
   --n_hidden 100 \
   --epochs 1

@@ -14,7 +14,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -14,12 +14,12 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \
   --n_hidden 100 --epochs 1 \
-  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_bytes' --bytes_output_mode \
+  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt_bytes' --bytes_output_mode true \
   --learning_rate 0.001 --dropout_rate 0.05 \
   --scorer_path 'data/smoke_test/pruned_lm.bytes.scorer' | tee /tmp/resume.log
 

@@ -20,7 +20,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} --train_batch_size 1 \
   --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
   --test_files ${ldc93s1_sdb} --test_batch_size 1 \

@@ -17,7 +17,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache' \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \

@@ -17,7 +17,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache' \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
@@ -27,4 +27,4 @@ python -u train.py --noshow_progressbar --noearly_stop \
   --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train_bytes' \
   --scorer_path 'data/smoke_test/pruned_lm.bytes.scorer' \
   --audio_sample_rate ${audio_sample_rate} \
-  --bytes_output_mode
+  --bytes_output_mode true

@@ -16,11 +16,11 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar \
+python -u train.py --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt_bytes' \
   --export_dir '/tmp/train_bytes_tflite' \
   --scorer_path 'data/smoke_test/pruned_lm.bytes.scorer' \
-  --bytes_output_mode \
+  --bytes_output_mode true \
   --audio_sample_rate ${audio_sample_rate} \
-  --export_tflite
+  --export_tflite true

@@ -17,7 +17,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -23,7 +23,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb} --train_batch_size 1 \
   --dev_files ${ldc93s1_sdb} --dev_batch_size 1 \
   --test_files ${ldc93s1_sdb} --test_batch_size 1 \

@@ -23,7 +23,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_sdb},${ldc93s1_csv} --train_batch_size 1 \
   --feature_cache '/tmp/ldc93s1_cache_sdb_csv' \
   --dev_files ${ldc93s1_sdb},${ldc93s1_csv} --dev_batch_size 1 \

@@ -14,7 +14,7 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar --noearly_stop \
+python -u train.py --show_progressbar false --early_stop false \
   --train_files ${ldc93s1_csv} --train_batch_size 1 \
   --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
   --test_files ${ldc93s1_csv} --test_batch_size 1 \

@@ -16,21 +16,21 @@ fi;
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar \
+python -u train.py --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --export_dir '/tmp/train_tflite' \
   --scorer_path 'data/smoke_test/pruned_lm.scorer' \
   --audio_sample_rate ${audio_sample_rate} \
-  --export_tflite
+  --export_tflite true
 
 mkdir /tmp/train_tflite/en-us
 
-python -u train.py --noshow_progressbar \
+python -u train.py --show_progressbar false \
   --n_hidden 100 \
   --checkpoint_dir '/tmp/ckpt' \
   --export_dir '/tmp/train_tflite/en-us' \
   --scorer_path 'data/smoke_test/pruned_lm.scorer' \
   --audio_sample_rate ${audio_sample_rate} \
   --export_language 'Fake English (fk-FK)' \
-  --export_zip
+  --export_zip true

@@ -29,7 +29,7 @@ for LOAD in 'init' 'last' 'auto'; do
   echo "########################################################"
   echo "#### Train ENGLISH model with just --checkpoint_dir ####"
   echo "########################################################"
-  python -u train.py --noshow_progressbar --noearly_stop \
+  python -u train.py --show_progressbar false --early_stop false \
     --alphabet_config_path "./data/alphabet.txt" \
     --load_train "$LOAD" \
     --train_files "${ldc93s1_csv}" --train_batch_size 1 \
@@ -43,7 +43,7 @@ for LOAD in 'init' 'last' 'auto'; do
   echo "##############################################################################"
   echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
   echo "##############################################################################"
-  python -u train.py --noshow_progressbar --noearly_stop \
+  python -u train.py --show_progressbar false --early_stop false \
    --alphabet_config_path "./data/alphabet.txt" \
    --load_train "$LOAD" \
    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
@@ -58,7 +58,7 @@ for LOAD in 'init' 'last' 'auto'; do
   echo "####################################################################################"
   echo "#### Transfer to RUSSIAN model with --save_checkpoint_dir --load_checkpoint_dir ####"
   echo "####################################################################################"
-  python -u train.py --noshow_progressbar --noearly_stop \
+  python -u train.py --show_progressbar false --early_stop false \
    --drop_source_layers 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load_train 'last' \
@@ -71,7 +71,7 @@ for LOAD in 'init' 'last' 'auto'; do
     --epochs 10
 
   # Test transfer learning checkpoint
-  python -u evaluate.py --noshow_progressbar \
+  python -u evaluate.py --show_progressbar false \
    --test_files "${ru_csv}" --test_batch_size 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load_checkpoint_dir '/tmp/ckpt/transfer/ru' \

@@ -20,7 +20,7 @@ fi
 # and when trying to run on multiple devices (like GPUs), this will break
 export CUDA_VISIBLE_DEVICES=0
 
-python -u train.py --noshow_progressbar \
+python -u train.py --show_progressbar false \
   --train_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \
   --train_batch_size 1 \

@@ -9,4 +9,4 @@ if __name__ == "__main__":
         print("Training package is not installed. See training documentation.")
         raise
 
-    ds_evaluate.run_script()
+    ds_evaluate.main()

@@ -60,7 +60,7 @@ def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
         queue_in.task_done()
 
 
-def main(args, _):
+def main(args):
     manager = Manager()
     work_todo = JoinableQueue()  # this is where we are going to store input data
     work_done = manager.Queue()  # this where we are gonna push them out
@@ -154,5 +154,4 @@ def parse_args():
 
 
 if __name__ == "__main__":
-    create_flags()
-    absl.app.run(partial(main, parse_args()))
+    main(parse_args())

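The same entry-point cleanup repeats across the tool scripts: create_flags() plus the absl.app.run indirection (which existed only so absl could parse FLAGS) is dropped, main() loses absl's unused positional argument, and scripts call main() directly. A sketch of the resulting pattern, with parse_args standing in for each script's own argparse-based parser:

    # Before: absl owned the process entry point in order to parse FLAGS:
    #     create_flags()
    #     absl.app.run(partial(main, parse_args()))
    # After: configuration is parsed up front, so no wrapper is needed:
    if __name__ == "__main__":
        main(parse_args())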
setup.py
@@ -13,9 +13,9 @@ def main():
         version = fin.read().strip()
 
     install_requires_base = [
-        "absl-py",
         "attrdict",
         "bs4",
+        "coqpit",
         "numpy",
         "optuna",
         "opuslib == 2.0.0",

train.py
@@ -9,4 +9,4 @@ if __name__ == "__main__":
         print("Training package is not installed. See training documentation.")
         raise
 
-    ds_train.run_script()
+    ds_train.main()

@@ -6,7 +6,6 @@ import json
 import sys
 from multiprocessing import cpu_count
 
-import absl.app
 import progressbar
 from coqui_stt_ctcdecoder import Scorer, ctc_beam_search_decoder_batch
 from six.moves import zip
@@ -16,12 +15,16 @@ import tensorflow.compat.v1 as tfv1
 
 from .util.augmentations import NormalizeSampleRate
 from .util.checkpoints import load_graph_for_evaluation
-from .util.config import Config, initialize_globals
+from .util.config import (
+    Config,
+    create_progressbar,
+    initialize_globals,
+    log_error,
+    log_progress,
+)
 from .util.evaluate_tools import calculate_and_print_report, save_samples_json
 from .util.feeding import create_dataset
-from .util.flags import FLAGS, create_flags
 from .util.helpers import check_ctcdecoder_version
-from .util.logging import create_progressbar, log_error, log_progress
 
 check_ctcdecoder_version()
 

@@ -47,9 +50,9 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
 
 
 def evaluate(test_csvs, create_model):
-    if FLAGS.scorer_path:
+    if Config.scorer_path:
         scorer = Scorer(
-            FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
+            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
         )
     else:
         scorer = None
@@ -57,11 +60,11 @@ def evaluate(test_csvs, create_model):
     test_sets = [
         create_dataset(
             [csv],
-            batch_size=FLAGS.test_batch_size,
+            batch_size=Config.test_batch_size,
             train_phase=False,
-            augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)],
-            reverse=FLAGS.reverse_test,
-            limit=FLAGS.limit_test,
+            augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
+            reverse=Config.reverse_test,
+            limit=Config.limit_test,
         )
         for csv in test_csvs
     ]
@@ -132,11 +135,11 @@ def evaluate(test_csvs, create_model):
             batch_logits,
             batch_lengths,
             Config.alphabet,
-            FLAGS.beam_width,
+            Config.beam_width,
             num_processes=num_processes,
             scorer=scorer,
-            cutoff_prob=FLAGS.cutoff_prob,
-            cutoff_top_n=FLAGS.cutoff_top_n,
+            cutoff_prob=Config.cutoff_prob,
+            cutoff_top_n=Config.cutoff_top_n,
         )
         predictions.extend(d[0][1] for d in decoded)
         ground_truths.extend(
@@ -165,10 +168,10 @@ def evaluate(test_csvs, create_model):
     return samples
 
 
-def main(_):
+def main():
     initialize_globals()
 
-    if not FLAGS.test_files:
+    if not Config.test_files:
         log_error(
             "You need to specify what files to use for evaluation via "
             "the --test_files flag."
@@ -179,16 +182,11 @@ def main(_):
         create_model,
     )
 
-    samples = evaluate(FLAGS.test_files.split(","), create_model)
+    samples = evaluate(Config.test_files, create_model)
 
-    if FLAGS.test_output_file:
-        save_samples_json(samples, FLAGS.test_output_file)
-
-
-def run_script():
-    create_flags()
-    absl.app.run(main)
+    if Config.test_output_file:
+        save_samples_json(samples, Config.test_output_file)
 
 
 if __name__ == "__main__":
-    run_script()
+    main()

@@ -11,10 +11,10 @@ DESIRED_LOG_LEVEL = (
 )
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL
 
+import json
 import shutil
 import time
 
-import absl.app
 import numpy as np
 import progressbar
 
@@ -42,10 +42,18 @@ from .util.checkpoints import (
     load_or_init_graph_for_training,
     reload_best_checkpoint,
 )
-from .util.config import Config, initialize_globals
+from .util.config import (
+    Config,
+    create_progressbar,
+    initialize_globals,
+    log_debug,
+    log_error,
+    log_info,
+    log_progress,
+    log_warn,
+)
 from .util.evaluate_tools import save_samples_json
 from .util.feeding import audio_to_features, audiofile_to_features, create_dataset
-from .util.flags import FLAGS, create_flags
 from .util.helpers import ExceptionBox, check_ctcdecoder_version
 from .util.io import (
     is_remote_path,
@@ -54,14 +62,6 @@ from .util.io import (
     open_remote,
     remove_remote,
 )
-from .util.logging import (
-    create_progressbar,
-    log_debug,
-    log_error,
-    log_info,
-    log_progress,
-    log_warn,
-)
 
 check_ctcdecoder_version()
 
@@ -120,7 +120,7 @@ def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False):
     output = tf.nn.bias_add(tf.matmul(x, weights), bias)
 
     if relu:
-        output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
+        output = tf.minimum(tf.nn.relu(output), Config.relu_clip)
 
     if layer_norm:
         with tfv1.variable_scope(name):
@@ -249,21 +249,21 @@ def create_model(
         batch_x,
         Config.n_hidden_1,
         dropout_rate=dropout[0],
-        layer_norm=FLAGS.layer_norm,
+        layer_norm=Config.layer_norm,
     )
     layers["layer_2"] = layer_2 = dense(
         "layer_2",
         layer_1,
         Config.n_hidden_2,
         dropout_rate=dropout[1],
-        layer_norm=FLAGS.layer_norm,
+        layer_norm=Config.layer_norm,
     )
     layers["layer_3"] = layer_3 = dense(
         "layer_3",
         layer_2,
         Config.n_hidden_3,
         dropout_rate=dropout[2],
-        layer_norm=FLAGS.layer_norm,
+        layer_norm=Config.layer_norm,
     )
 
     # `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
@@ -286,7 +286,7 @@ def create_model(
         output,
         Config.n_hidden_5,
         dropout_rate=dropout[5],
-        layer_norm=FLAGS.layer_norm,
+        layer_norm=Config.layer_norm,
     )
 
     # Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
@@ -326,7 +326,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
-    if FLAGS.train_cudnn:
+    if Config.train_cudnn:
         rnn_impl = rnn_impl_cudnn_rnn
     else:
         rnn_impl = rnn_impl_lstmblockfusedcell
@@ -365,9 +365,9 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
 def create_optimizer(learning_rate_var):
     optimizer = tfv1.train.AdamOptimizer(
         learning_rate=learning_rate_var,
-        beta1=FLAGS.beta1,
-        beta2=FLAGS.beta2,
-        epsilon=FLAGS.epsilon,
+        beta1=Config.beta1,
+        beta2=Config.beta2,
+        epsilon=Config.epsilon,
     )
     return optimizer
 
@@ -526,17 +526,17 @@ def train():
 
     # Create training and validation datasets
     train_set = create_dataset(
-        FLAGS.train_files.split(","),
-        batch_size=FLAGS.train_batch_size,
-        epochs=FLAGS.epochs,
+        Config.train_files,
+        batch_size=Config.train_batch_size,
+        epochs=Config.epochs,
         augmentations=Config.augmentations,
-        cache_path=FLAGS.feature_cache,
+        cache_path=Config.feature_cache,
         train_phase=True,
         exception_box=exception_box,
-        process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
-        reverse=FLAGS.reverse_train,
-        limit=FLAGS.limit_train,
-        buffering=FLAGS.read_buffer,
+        process_ahead=len(Config.available_devices) * Config.train_batch_size * 2,
+        reverse=Config.reverse_train,
+        limit=Config.limit_train,
+        buffering=Config.read_buffer,
     )
 
     iterator = tfv1.data.Iterator.from_structure(
@@ -548,37 +548,37 @@ def train():
     # Make initialization ops for switching between the two sets
     train_init_op = iterator.make_initializer(train_set)
 
-    if FLAGS.dev_files:
-        dev_sources = FLAGS.dev_files.split(",")
+    if Config.dev_files:
+        dev_sources = Config.dev_files
         dev_sets = [
             create_dataset(
                 [source],
-                batch_size=FLAGS.dev_batch_size,
+                batch_size=Config.dev_batch_size,
                 train_phase=False,
-                augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)],
+                augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
                 exception_box=exception_box,
-                process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
-                reverse=FLAGS.reverse_dev,
-                limit=FLAGS.limit_dev,
-                buffering=FLAGS.read_buffer,
+                process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
+                reverse=Config.reverse_dev,
+                limit=Config.limit_dev,
+                buffering=Config.read_buffer,
             )
             for source in dev_sources
         ]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
-    if FLAGS.metrics_files:
-        metrics_sources = FLAGS.metrics_files.split(",")
+    if Config.metrics_files:
+        metrics_sources = Config.metrics_files
         metrics_sets = [
             create_dataset(
                 [source],
-                batch_size=FLAGS.dev_batch_size,
+                batch_size=Config.dev_batch_size,
                 train_phase=False,
-                augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)],
+                augmentations=[NormalizeSampleRate(Config.audio_sample_rate)],
                 exception_box=exception_box,
-                process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
-                reverse=FLAGS.reverse_dev,
-                limit=FLAGS.limit_dev,
-                buffering=FLAGS.read_buffer,
+                process_ahead=len(Config.available_devices) * Config.dev_batch_size * 2,
+                reverse=Config.reverse_dev,
+                limit=Config.limit_dev,
+                buffering=Config.read_buffer,
             )
             for source in metrics_sources
         ]
@@ -591,26 +591,26 @@ def train():
         tfv1.placeholder(tf.float32, name="dropout_{}".format(i)) for i in range(6)
     ]
     dropout_feed_dict = {
-        dropout_rates[0]: FLAGS.dropout_rate,
-        dropout_rates[1]: FLAGS.dropout_rate2,
-        dropout_rates[2]: FLAGS.dropout_rate3,
-        dropout_rates[3]: FLAGS.dropout_rate4,
-        dropout_rates[4]: FLAGS.dropout_rate5,
-        dropout_rates[5]: FLAGS.dropout_rate6,
+        dropout_rates[0]: Config.dropout_rate,
+        dropout_rates[1]: Config.dropout_rate2,
+        dropout_rates[2]: Config.dropout_rate3,
+        dropout_rates[3]: Config.dropout_rate4,
+        dropout_rates[4]: Config.dropout_rate5,
+        dropout_rates[5]: Config.dropout_rate6,
     }
     no_dropout_feed_dict = {rate: 0.0 for rate in dropout_rates}
 
     # Building the graph
     learning_rate_var = tfv1.get_variable(
-        "learning_rate", initializer=FLAGS.learning_rate, trainable=False
+        "learning_rate", initializer=Config.learning_rate, trainable=False
     )
     reduce_learning_rate_op = learning_rate_var.assign(
-        tf.multiply(learning_rate_var, FLAGS.plateau_reduction)
+        tf.multiply(learning_rate_var, Config.plateau_reduction)
     )
     optimizer = create_optimizer(learning_rate_var)
 
     # Enable mixed precision training
-    if FLAGS.automatic_mixed_precision:
+    if Config.automatic_mixed_precision:
         log_info("Enabling automatic mixed precision training.")
         optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(
             optimizer
@@ -634,13 +634,13 @@ def train():
     step_summaries_op = tfv1.summary.merge_all("step_summaries")
     step_summary_writers = {
         "train": tfv1.summary.FileWriter(
-            os.path.join(FLAGS.summary_dir, "train"), max_queue=120
+            os.path.join(Config.summary_dir, "train"), max_queue=120
         ),
         "dev": tfv1.summary.FileWriter(
-            os.path.join(FLAGS.summary_dir, "dev"), max_queue=120
+            os.path.join(Config.summary_dir, "dev"), max_queue=120
         ),
         "metrics": tfv1.summary.FileWriter(
-            os.path.join(FLAGS.summary_dir, "metrics"), max_queue=120
+            os.path.join(Config.summary_dir, "metrics"), max_queue=120
         ),
     }
 
@@ -651,18 +651,18 @@ def train():
     }
 
     # Checkpointing
-    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
-    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, "train")
+    checkpoint_saver = tfv1.train.Saver(max_to_keep=Config.max_to_keep)
+    checkpoint_path = os.path.join(Config.save_checkpoint_dir, "train")
 
     best_dev_saver = tfv1.train.Saver(max_to_keep=1)
-    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, "best_dev")
+    best_dev_path = os.path.join(Config.save_checkpoint_dir, "best_dev")
 
     # Save flags next to checkpoints
-    if not is_remote_path(FLAGS.save_checkpoint_dir):
-        os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
-    flags_file = os.path.join(FLAGS.save_checkpoint_dir, "flags.txt")
+    if not is_remote_path(Config.save_checkpoint_dir):
+        os.makedirs(Config.save_checkpoint_dir, exist_ok=True)
+    flags_file = os.path.join(Config.save_checkpoint_dir, "flags.txt")
     with open_remote(flags_file, "w") as fout:
-        fout.write(FLAGS.flags_into_string())
+        json.dump(Config.serialize(), fout, indent=2)
 
     with tfv1.Session(config=Config.session_config) as session:
         log_debug("Session opened.")
@@ -684,9 +684,9 @@ def train():
         step_summary_writer = step_summary_writers.get(set_name)
         checkpoint_time = time.time()
 
-        if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache:
-            feature_cache_index = FLAGS.feature_cache + ".index"
-            if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(
+        if is_train and Config.cache_for_epochs > 0 and Config.feature_cache:
+            feature_cache_index = Config.feature_cache + ".index"
+            if epoch % Config.cache_for_epochs == 0 and os.path.isfile(
                 feature_cache_index
             ):
                 log_info("Invalidating feature cache")
@@ -766,8 +766,8 @@ def train():
 
             if (
                 is_train
-                and FLAGS.checkpoint_secs > 0
-                and time.time() - checkpoint_time > FLAGS.checkpoint_secs
+                and Config.checkpoint_secs > 0
+                and time.time() - checkpoint_time > Config.checkpoint_secs
             ):
                 checkpoint_saver.save(
                     session, checkpoint_path, global_step=current_step
@@ -784,7 +784,7 @@ def train():
     dev_losses = []
     epochs_without_improvement = 0
     try:
-        for epoch in range(FLAGS.epochs):
+        for epoch in range(Config.epochs):
             # Training
             log_progress("Training epoch %d..." % epoch)
             train_loss, _ = run_set("train", epoch, train_init_op)
@@ -793,7 +793,7 @@ def train():
             )
             checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
-            if FLAGS.dev_files:
+            if Config.dev_files:
                 # Validation
                 dev_loss = 0.0
                 total_steps = 0
@@ -811,8 +811,8 @@ def train():
                 dev_losses.append(dev_loss)
 
                 # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau
-                # the improvement has to be greater than FLAGS.es_min_delta
-                if dev_loss > best_dev_loss - FLAGS.es_min_delta:
+                # the improvement has to be greater than Config.es_min_delta
+                if dev_loss > best_dev_loss - Config.es_min_delta:
                     epochs_without_improvement += 1
                 else:
                     epochs_without_improvement = 0
@@ -833,8 +833,8 @@ def train():
 
                 # Early stopping
                 if (
-                    FLAGS.early_stop
-                    and epochs_without_improvement == FLAGS.es_epochs
+                    Config.early_stop
+                    and epochs_without_improvement == Config.es_epochs
                 ):
                     log_info(
                         "Early stop triggered as the loss did not improve the last {} epochs".format(
@@ -845,11 +845,11 @@ def train():
 
                 # Reduce learning rate on plateau
                 # If the learning rate was reduced and there is still no improvement
-                # wait FLAGS.plateau_epochs before the learning rate is reduced again
+                # wait Config.plateau_epochs before the learning rate is reduced again
                 if (
-                    FLAGS.reduce_lr_on_plateau
+                    Config.reduce_lr_on_plateau
                     and epochs_without_improvement > 0
-                    and epochs_without_improvement % FLAGS.plateau_epochs == 0
+                    and epochs_without_improvement % Config.plateau_epochs == 0
                 ):
                     # Reload checkpoint that we use the best_dev weights again
                     reload_best_checkpoint(session)
@@ -875,7 +875,7 @@ def train():
                     % (save_path)
                 )
 
-            if FLAGS.metrics_files:
+            if Config.metrics_files:
                 # Read only metrics, not affecting best validation loss tracking
                 for source, init_op in zip(metrics_sources, metrics_init_ops):
                     log_progress("Metrics for epoch %d on %s..." % (epoch, source))
@@ -896,9 +896,9 @@ def train():
 
 
 def test():
-    samples = evaluate(FLAGS.test_files.split(","), create_model)
-    if FLAGS.test_output_file:
-        save_samples_json(samples, FLAGS.test_output_file)
+    samples = evaluate(Config.test_files, create_model)
+    if Config.test_output_file:
+        save_samples_json(samples, Config.test_output_file)
 
 
 def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
@@ -909,7 +909,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         tf.float32, [Config.audio_window_samples], "input_samples"
     )
     samples = tf.expand_dims(input_samples, -1)
-    mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate)
+    mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
     mfccs = tf.identity(mfccs, name="mfccs")
 
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
@@ -954,7 +954,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     logits, layers = create_model(
         batch_x=input_tensor,
         batch_size=batch_size,
-        seq_length=seq_length if not FLAGS.export_tflite else None,
+        seq_length=seq_length if not Config.export_tflite else None,
         dropout=no_dropout,
         previous_state=previous_state,
         overlap=False,
@@ -1001,7 +1001,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         "input_samples": input_samples,
     }
 
-    if not FLAGS.export_tflite:
+    if not Config.export_tflite:
         inputs["input_lengths"] = seq_length
 
     outputs = {
@@ -1028,9 +1028,9 @@ def export():
     log_info("Exporting the model...")
 
     inputs, outputs, _ = create_inference_graph(
-        batch_size=FLAGS.export_batch_size,
-        n_steps=FLAGS.n_steps,
-        tflite=FLAGS.export_tflite,
+        batch_size=Config.export_batch_size,
+        n_steps=Config.n_steps,
+        tflite=Config.export_tflite,
     )
 
     graph_version = int(file_relative_read("GRAPH_VERSION").strip())
@@ -1038,24 +1038,24 @@ def export():
 
     outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
     outputs["metadata_sample_rate"] = tf.constant(
-        [FLAGS.audio_sample_rate], name="metadata_sample_rate"
+        [Config.audio_sample_rate], name="metadata_sample_rate"
     )
     outputs["metadata_feature_win_len"] = tf.constant(
-        [FLAGS.feature_win_len], name="metadata_feature_win_len"
+        [Config.feature_win_len], name="metadata_feature_win_len"
     )
     outputs["metadata_feature_win_step"] = tf.constant(
-        [FLAGS.feature_win_step], name="metadata_feature_win_step"
+        [Config.feature_win_step], name="metadata_feature_win_step"
    )
     outputs["metadata_beam_width"] = tf.constant(
-        [FLAGS.export_beam_width], name="metadata_beam_width"
+        [Config.export_beam_width], name="metadata_beam_width"
     )
     outputs["metadata_alphabet"] = tf.constant(
         [Config.alphabet.Serialize()], name="metadata_alphabet"
     )
 
-    if FLAGS.export_language:
+    if Config.export_language:
         outputs["metadata_language"] = tf.constant(
-            [FLAGS.export_language.encode("utf-8")], name="metadata_language"
+            [Config.export_language.encode("utf-8")], name="metadata_language"
         )
 
     # Prevent further graph changes
@@ -1073,16 +1073,18 @@ def export():
         # Restore variables from checkpoint
         load_graph_for_evaluation(session)
 
-        output_filename = FLAGS.export_file_name + ".pb"
-        if FLAGS.remove_export:
-            if isdir_remote(FLAGS.export_dir):
+        output_filename = Config.export_file_name + ".pb"
+        if Config.remove_export:
+            if isdir_remote(Config.export_dir):
                 log_info("Removing old export")
-                remove_remote(FLAGS.export_dir)
+                remove_remote(Config.export_dir)
 
-        output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
+        output_graph_path = os.path.join(Config.export_dir, output_filename)
 
-        if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
-            os.makedirs(FLAGS.export_dir)
+        if not is_remote_path(Config.export_dir) and not os.path.isdir(
+            Config.export_dir
+        ):
+            os.makedirs(Config.export_dir)
 
         frozen_graph = tfv1.graph_util.convert_variables_to_constants(
             sess=session,
@@ -1094,12 +1096,12 @@ def export():
             graph_def=frozen_graph, dest_nodes=output_names
         )
 
-        if not FLAGS.export_tflite:
+        if not Config.export_tflite:
             with open_remote(output_graph_path, "wb") as fout:
                 fout.write(frozen_graph.SerializeToString())
         else:
             output_tflite_path = os.path.join(
-                FLAGS.export_dir, output_filename.replace(".pb", ".tflite")
+                Config.export_dir, output_filename.replace(".pb", ".tflite")
             )
 
             converter = tf.lite.TFLiteConverter(
@@ -1115,27 +1117,29 @@ def export():
             with open_remote(output_tflite_path, "wb") as fout:
                 fout.write(tflite_model)
 
-    log_info("Models exported at %s" % (FLAGS.export_dir))
+    log_info("Models exported at %s" % (Config.export_dir))
 
     metadata_fname = os.path.join(
-        FLAGS.export_dir,
+        Config.export_dir,
         "{}_{}_{}.md".format(
-            FLAGS.export_author_id, FLAGS.export_model_name, FLAGS.export_model_version
+            Config.export_author_id,
+            Config.export_model_name,
+            Config.export_model_version,
         ),
     )
 
-    model_runtime = "tflite" if FLAGS.export_tflite else "tensorflow"
+    model_runtime = "tflite" if Config.export_tflite else "tensorflow"
     with open_remote(metadata_fname, "w") as f:
         f.write("---\n")
-        f.write("author: {}\n".format(FLAGS.export_author_id))
-        f.write("model_name: {}\n".format(FLAGS.export_model_name))
-        f.write("model_version: {}\n".format(FLAGS.export_model_version))
-        f.write("contact_info: {}\n".format(FLAGS.export_contact_info))
-        f.write("license: {}\n".format(FLAGS.export_license))
-        f.write("language: {}\n".format(FLAGS.export_language))
+        f.write("author: {}\n".format(Config.export_author_id))
+        f.write("model_name: {}\n".format(Config.export_model_name))
+        f.write("model_version: {}\n".format(Config.export_model_version))
+        f.write("contact_info: {}\n".format(Config.export_contact_info))
+        f.write("license: {}\n".format(Config.export_license))
+        f.write("language: {}\n".format(Config.export_language))
         f.write("runtime: {}\n".format(model_runtime))
-        f.write("min_stt_version: {}\n".format(FLAGS.export_min_stt_version))
-        f.write("max_stt_version: {}\n".format(FLAGS.export_max_stt_version))
+        f.write("min_stt_version: {}\n".format(Config.export_min_stt_version))
+        f.write("max_stt_version: {}\n".format(Config.export_max_stt_version))
         f.write(
             "acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n"
         )
@@ -1143,7 +1147,7 @@ def export():
             "scorer_url: <replace this with a publicly available URL of the scorer, if present>\n"
         )
         f.write("---\n")
-        f.write("{}\n".format(FLAGS.export_description))
+        f.write("{}\n".format(Config.export_description))
 
     log_info(
         "Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.".format(
@@ -1154,7 +1158,9 @@ def export():
 
 def package_zip():
     # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
-    export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), "")  # Force ending '/'
+    export_dir = os.path.join(
+        os.path.abspath(Config.export_dir), ""
+    )  # Force ending '/'
     if is_remote_path(export_dir):
         log_error(
             "Cannot package remote path zip %s. Please do this manually." % export_dir
@@ -1163,7 +1169,7 @@ def package_zip():
 
     zip_filename = os.path.dirname(export_dir)
 
-    shutil.copy(FLAGS.scorer_path, export_dir)
+    shutil.copy(Config.scorer_path, export_dir)
 
     archive = shutil.make_archive(zip_filename, "zip", export_dir)
     log_info("Exported packaged model {}".format(archive))
@@ -1200,19 +1206,19 @@ def do_single_file_inference(input_file_path):
 
         probs = np.squeeze(probs)
 
-        if FLAGS.scorer_path:
+        if Config.scorer_path:
             scorer = Scorer(
-                FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
+                Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
             )
         else:
             scorer = None
         decoded = ctc_beam_search_decoder(
             probs,
             Config.alphabet,
-            FLAGS.beam_width,
+            Config.beam_width,
             scorer=scorer,
-            cutoff_prob=FLAGS.cutoff_prob,
-            cutoff_top_n=FLAGS.cutoff_top_n,
+            cutoff_prob=Config.cutoff_prob,
+            cutoff_top_n=Config.cutoff_top_n,
         )
         # Print highest probability result
         print(decoded[0][1])
@@ -1220,16 +1226,16 @@ def do_single_file_inference(input_file_path):
 
 def early_training_checks():
     # Check for proper scorer early
-    if FLAGS.scorer_path:
+    if Config.scorer_path:
         scorer = Scorer(
-            FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
+            Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
         )
         del scorer
 
     if (
-        FLAGS.train_files
-        and FLAGS.test_files
-        and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir
+        Config.train_files
+        and Config.test_files
+        and Config.load_checkpoint_dir != Config.save_checkpoint_dir
     ):
         log_warn(
             "WARNING: You specified different values for --load_checkpoint_dir "
@@ -1242,45 +1248,40 @@ def early_training_checks():
     )
 
 
-def main(_):
+def main():
     initialize_globals()
     early_training_checks()
 
-    if FLAGS.train_files:
+    if Config.train_files:
         tfv1.reset_default_graph()
-        tfv1.set_random_seed(FLAGS.random_seed)
+        tfv1.set_random_seed(Config.random_seed)
         train()
 
-    if FLAGS.test_files:
+    if Config.test_files:
         tfv1.reset_default_graph()
         test()
 
-    if FLAGS.export_dir and not FLAGS.export_zip:
+    if Config.export_dir and not Config.export_zip:
         tfv1.reset_default_graph()
         export()
 
-    if FLAGS.export_zip:
+    if Config.export_zip:
         tfv1.reset_default_graph()
-        FLAGS.export_tflite = True
+        Config.export_tflite = True
 
-        if listdir_remote(FLAGS.export_dir):
+        if listdir_remote(Config.export_dir):
             log_error(
-                "Directory {} is not empty, please fix this.".format(FLAGS.export_dir)
+                "Directory {} is not empty, please fix this.".format(Config.export_dir)
             )
             sys.exit(1)
 
         export()
         package_zip()
 
-    if FLAGS.one_shot_infer:
+    if Config.one_shot_infer:
         tfv1.reset_default_graph()
-        do_single_file_inference(FLAGS.one_shot_infer)
-
-
-def run_script():
-    create_flags()
-    absl.app.run(main)
+        do_single_file_inference(Config.one_shot_infer)
 
 
 if __name__ == "__main__":
-    run_script()
+    main()

@@ -15,6 +15,7 @@ from .audio import (
     max_dbfs,
     normalize_audio,
 )
+from .config import log_info
 from .helpers import (
     MEGABYTE,
     LimitingPool,
@@ -23,7 +24,6 @@ from .helpers import (
     pick_value_from_range,
     tf_pick_value_from_range,
 )
-from .logging import log_info
 from .sample_collections import samples_from_source, unpack_maybe
 
 BUFFER_SIZE = 1 * MEGABYTE
@@ -81,7 +81,7 @@ class GraphAugmentation(Augmentation):
         return (
             FLAGS.audio_sample_rate / 1000.0
             if self.domain == "signal"
-            else 1.0 / FLAGS.feature_win_step
+            else 1.0 / Config.feature_win_step
         )
 
 

@@ -3,8 +3,7 @@ import sys
 import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
 
-from .flags import FLAGS
-from .logging import log_error, log_info, log_warn
+from .config import Config, log_error, log_info, log_warn
 
 
 def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@@ -21,13 +20,13 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
     lr_var = set(v for v in load_vars if v.op.name == "learning_rate")
     if lr_var and (
         "learning_rate" not in vars_in_ckpt
-        or (FLAGS.force_initialize_learning_rate and allow_lr_init)
+        or (Config.force_initialize_learning_rate and allow_lr_init)
     ):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
 
-    if FLAGS.load_cudnn:
+    if Config.load_cudnn:
         # Initialize training from a CuDNN RNN checkpoint
         # Identify the variables which we cannot load, and set them
         # for initialization
@@ -51,23 +50,23 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
         )
         sys.exit(1)
 
-    if allow_drop_layers and FLAGS.drop_source_layers > 0:
+    if allow_drop_layers and Config.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
         # Say we want to exclude all layers except for the first one,
         # then we are dropping five layers total, so: drop_source_layers=5
         # If we want to use all layers from the source model except
         # the last one, we use this: drop_source_layers=1
-        if FLAGS.drop_source_layers >= 6:
+        if Config.drop_source_layers >= 6:
             log_warn(
                 "The checkpoint only has 6 layers, but you are trying to drop "
                 "all of them or more than all of them. Continuing and "
                 "dropping only 5 layers."
             )
-            FLAGS.drop_source_layers = 5
+            Config.drop_source_layers = 5
 
         dropped_layers = ["2", "3", "lstm", "5", "6"][
-            -1 * int(FLAGS.drop_source_layers) :
+            -1 * int(Config.drop_source_layers) :
         ]
         # Initialize all variables needed for DS, but not loaded from ckpt
         for v in load_vars:
@@ -86,7 +85,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
 
 def _checkpoint_path_or_none(checkpoint_filename):
     checkpoint = tfv1.train.get_checkpoint_state(
-        FLAGS.load_checkpoint_dir, checkpoint_filename
+        Config.load_checkpoint_dir, checkpoint_filename
     )
     if not checkpoint:
         return None
@@ -145,10 +144,10 @@ def load_or_init_graph_for_training(session):
     and finally initialize the weights from scratch. This can be overriden with
     the `--load_train` flag. See its documentation for more info.
     """
-    if FLAGS.load_train == "auto":
+    if Config.load_train == "auto":
         methods = ["best", "last", "init"]
     else:
-        methods = [FLAGS.load_train]
+        methods = [Config.load_train]
     _load_or_init_impl(session, methods, allow_drop_layers=True)
 
 
@@ -159,8 +158,8 @@ def load_graph_for_evaluation(session):
     checkpoint. This can be overriden with the `--load_evaluate` flag. See its
     documentation for more info.
     """
-    if FLAGS.load_evaluate == "auto":
+    if Config.load_evaluate == "auto":
         methods = ["best", "last"]
     else:
-        methods = [FLAGS.load_evaluate]
+        methods = [Config.load_evaluate]
     _load_or_init_impl(session, methods, allow_drop_layers=False)

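The bulk of the change is util/config.py below, where every absl flag becomes a typed field on a Coqpit dataclass. One detail worth noting: the file-list options use default_factory=list rather than default=[], which is how dataclasses require mutable defaults to be declared. A minimal sketch under plain dataclasses semantics (the Demo name is illustrative):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Demo:
        # Mutable defaults must go through default_factory; this is why the
        # train_files/dev_files/test_files fields below are declared this way.
        train_files: List[str] = field(default_factory=list)
        epochs: int = 75

    d = Demo()
    assert d.train_files == []  # each instance gets its own fresh list
    assert d.epochs == 75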
@ -2,92 +2,607 @@ from __future__ import absolute_import, division, print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import List
|
||||
|
||||
import progressbar
|
||||
from attrdict import AttrDict
|
||||
from coqpit import MISSING, Coqpit, check_argument
|
||||
from coqui_stt_ctcdecoder import Alphabet, UTF8Alphabet
|
||||
from xdg import BaseDirectory as xdg
|
||||
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
|
||||
from .augmentations import NormalizeSampleRate, parse_augmentations
|
||||
from .flags import FLAGS
|
||||
from .gpu import get_available_gpus
|
||||
from .helpers import parse_file_size
|
||||
from .io import path_exists_remote
|
||||
from .logging import log_error, log_warn
|
||||
|
||||
|
||||
class ConfigSingleton:
|
||||
class _ConfigSingleton:
|
||||
_config = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
if not ConfigSingleton._config:
|
||||
if not _ConfigSingleton._config:
|
||||
raise RuntimeError("Global configuration not yet initialized.")
|
||||
if not hasattr(ConfigSingleton._config, name):
|
||||
if not hasattr(_ConfigSingleton._config, name):
|
||||
raise RuntimeError(
|
||||
"Configuration option {} not found in config.".format(name)
|
||||
)
|
||||
return ConfigSingleton._config[name]
|
||||
return getattr(_ConfigSingleton._config, name)
|
||||
|
||||
|
||||
Config = ConfigSingleton() # pylint: disable=invalid-name
|
||||
Config = _ConfigSingleton() # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SttConfig(Coqpit):
|
||||
train_files: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata=dict(
|
||||
help="space-separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run."
|
||||
),
|
||||
)
|
||||
dev_files: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata=dict(
|
||||
help="space-separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run."
|
||||
),
|
||||
)
|
||||
test_files: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata=dict(
|
||||
help="space-separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested."
|
||||
),
|
||||
)
|
||||
metrics_files: List[str] = field(
|
||||
default_factory=list,
|
||||
metadata=dict(
|
||||
help="space-separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed."
|
||||
),
|
||||
)
|
||||
|
||||
read_buffer: str = field(
|
||||
default="1MB",
|
||||
metadata=dict(
|
||||
help="buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)"
|
||||
),
|
||||
)
|
||||
feature_cache: str = field(
|
||||
default="",
|
||||
metadata=dict(
|
||||
help="cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled."
|
||||
),
|
||||
)
|
||||
cache_for_epochs: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help='after how many epochs the feature cache is invalidated again - 0 for "never"'
|
||||
),
|
||||
)
|
||||
|
||||
feature_win_len: int = field(
|
||||
default=32,
|
||||
metadata=dict(help="feature extraction audio window length in milliseconds"),
|
||||
)
|
||||
feature_win_step: int = field(
|
||||
default=20,
|
||||
metadata=dict(help="feature extraction window step length in milliseconds"),
|
||||
)
|
||||
audio_sample_rate: int = field(
|
||||
default=16000, metadata=dict(help="sample rate value expected by model")
|
||||
)
|
||||
normalize_sample_rate: bool = field(
|
||||
default=True,
|
||||
metadata=dict(
|
||||
help="normalize sample rate of all train_files to --audio_sample_rate"
|
||||
),
|
||||
)
|
||||
|
||||
# Data Augmentation
|
||||
augment: List[str] = field(
|
||||
default=None,
|
||||
metadata=dict(
|
||||
help='space-separated list of augmenations for training samples. Format is "--augment operation1[param1=value1, ...] operation2[param1=value1, ...] ..."'
|
||||
),
|
||||
)
|
||||
|
||||
# Global Constants
|
||||
epochs: int = field(
|
||||
default=75,
|
||||
metadata=dict(
|
||||
help="how many epochs (complete runs through the train files) to train for"
|
||||
),
|
||||
)
|
||||
|
||||
dropout_rate: float = field(
|
||||
default=0.05, metadata=dict(help="dropout rate for feedforward layers")
|
||||
)
|
||||
dropout_rate2: float = field(
|
||||
default=-1.0,
|
||||
metadata=dict(help="dropout rate for layer 2 - defaults to dropout_rate"),
|
||||
)
|
||||
dropout_rate3: float = field(
|
||||
default=-1.0,
|
||||
metadata=dict(help="dropout rate for layer 3 - defaults to dropout_rate"),
|
||||
)
|
||||
dropout_rate4: float = field(
|
||||
default=0.0, metadata=dict(help="dropout rate for layer 4 - defaults to 0.0")
|
||||
)
|
||||
dropout_rate5: float = field(
|
||||
default=0.0, metadata=dict(help="dropout rate for layer 5 - defaults to 0.0")
|
||||
)
|
||||
dropout_rate6: float = field(
|
||||
default=-1.0,
|
||||
metadata=dict(help="dropout rate for layer 6 - defaults to dropout_rate"),
|
||||
)
|
||||
|
||||
relu_clip: float = field(
|
||||
default=20.0, metadata=dict(help="ReLU clipping value for non-recurrent layers")
|
||||
)
|
||||
|
||||
# Adam optimizer(http://arxiv.org/abs/1412.6980) parameters
|
||||
beta1: float = field(
|
||||
default=0.9, metadata=dict(help="beta 1 parameter of Adam optimizer")
|
||||
)
|
||||
beta2: float = field(
|
||||
default=0.999, metadata=dict(help="beta 2 parameter of Adam optimizer")
|
||||
)
|
||||
epsilon: float = field(
|
||||
default=1e-8, metadata=dict(help="epsilon parameter of Adam optimizer")
|
||||
)
|
||||
learning_rate: float = field(
|
||||
default=0.001, metadata=dict(help="learning rate of Adam optimizer")
|
||||
)
|
||||
|
||||
# Batch sizes
|
||||
train_batch_size: int = field(
|
||||
default=1, metadata=dict(help="number of elements in a training batch")
|
||||
)
|
||||
dev_batch_size: int = field(
|
||||
default=1, metadata=dict(help="number of elements in a validation batch")
|
||||
)
|
||||
test_batch_size: int = field(
|
||||
default=1, metadata=dict(help="number of elements in a test batch")
|
||||
)
|
||||
|
||||
export_batch_size: int = field(
|
||||
default=1,
|
||||
metadata=dict(help="number of elements per batch on the exported graph"),
|
||||
)
|
||||
|
||||
# Performance
|
||||
inter_op_parallelism_threads: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help="number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED"
|
||||
),
|
||||
)
|
||||
intra_op_parallelism_threads: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help="number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED"
|
||||
),
|
||||
)
|
||||
use_allow_growth: bool = field(
|
||||
default=False,
|
||||
metadata=dict(
|
||||
help="use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory"
|
||||
),
|
||||
)
|
||||
load_cudnn: bool = field(
|
||||
default=False,
|
||||
metadata=dict(
|
||||
help="Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph."
|
||||
),
|
||||
)
|
||||
train_cudnn: bool = field(
|
||||
default=False,
|
||||
metadata=dict(
|
||||
help="use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work"
|
||||
),
|
||||
)
|
||||
automatic_mixed_precision: bool = field(
|
||||
default=False,
|
||||
metadata=dict(
|
||||
help="whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision."
|
||||
),
|
||||
)
|
||||
|
||||
# Sample limits
|
||||
limit_train: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help="maximum number of elements to use from train set - 0 means no limit"
|
||||
),
|
||||
)
|
||||
limit_dev: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help="maximum number of elements to use from validation set - 0 means no limit"
|
||||
),
|
||||
)
|
||||
limit_test: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help="maximum number of elements to use from test set - 0 means no limit"
|
||||
),
|
||||
)
|
||||
|
||||
# Sample order
|
||||
reverse_train: bool = field(
|
||||
default=False, metadata=dict(help="if to reverse sample order of the train set")
|
||||
)
|
||||
reverse_dev: bool = field(
|
||||
default=False, metadata=dict(help="if to reverse sample order of the dev set")
|
||||
)
|
||||
reverse_test: bool = field(
|
||||
default=False, metadata=dict(help="if to reverse sample order of the test set")
|
||||
)
|
||||
|
||||
# Checkpointing
|
||||
checkpoint_dir: str = field(
|
||||
default="",
|
||||
metadata=dict(
|
||||
help='directory from which checkpoints are loaded and to which they are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification'
|
||||
),
|
||||
)
|
||||
load_checkpoint_dir: str = field(
|
||||
default="",
|
||||
metadata=dict(
|
||||
help='directory in which checkpoints are stored - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification'
|
||||
),
|
||||
)
|
||||
save_checkpoint_dir: str = field(
|
||||
default="",
|
||||
metadata=dict(
|
||||
help='directory to which checkpoints are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification'
|
||||
),
|
||||
)
|
||||
checkpoint_secs: int = field(
|
||||
default=600, metadata=dict(help="checkpoint saving interval in seconds")
|
||||
)
|
||||
max_to_keep: int = field(
|
||||
default=5,
|
||||
metadata=dict(help="number of checkpoint files to keep - default value is 5"),
|
||||
)
|
||||
load_train: str = field(
|
||||
default="auto",
|
||||
metadata=dict(
|
||||
help='what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.'
|
||||
),
|
||||
)
|
||||
load_evaluate: str = field(
|
||||
default="auto",
|
||||
metadata=dict(
|
||||
help='what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.'
|
||||
),
|
||||
)
|
||||
|
||||
# Transfer Learning
|
||||
drop_source_layers: int = field(
|
||||
default=0,
|
||||
metadata=dict(
|
||||
help="single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)"
|
||||
),
|
||||
)
|
||||
|
||||
    # Exporting
    export_dir: str = field(
        default="",
        metadata=dict(
            help="directory in which exported models are stored - if omitted, the model won't get exported"
        ),
    )
    remove_export: bool = field(
        default=False, metadata=dict(help="whether to remove old exported models")
    )
    export_tflite: bool = field(
        default=False, metadata=dict(help="export a graph ready for TF Lite engine")
    )
    n_steps: int = field(
        default=16,
        metadata=dict(
            help="how many timesteps to process at once by the export graph, higher values mean more latency"
        ),
    )
    export_zip: bool = field(
        default=False,
        metadata=dict(help="export a TFLite model and package with LM and info.json"),
    )
    export_file_name: str = field(
        default="output_graph",
        metadata=dict(help="name of the exported model file"),
    )
    export_beam_width: int = field(
        default=500,
        metadata=dict(help="default beam width to embed into exported graph"),
    )

    # Model metadata
    export_author_id: str = field(
        default="author",
        metadata=dict(
            help="author of the exported model. GitHub user or organization name used to uniquely identify the author of this model"
        ),
    )
    export_model_name: str = field(
        default="model",
        metadata=dict(
            help="name of the exported model. Must not contain forward slashes."
        ),
    )
    export_model_version: str = field(
        default="0.0.1",
        metadata=dict(
            help="semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with Coqui STT versions"
        ),
    )

    def field_val_equals_help(val_desc):
        return field(default="<{}>".format(val_desc), metadata=dict(help=val_desc))

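The field_val_equals_help helper above gives a field a default value equal to its own help text wrapped in angle brackets, presumably so that unset metadata fields remain self-describing in exported model metadata. A minimal sketch of the resulting behavior (the Demo class is hypothetical, not part of this commit):

from dataclasses import dataclass, field


def field_val_equals_help(val_desc):
    # Same trick as above: the default value is the help text in angle brackets.
    return field(default="<{}>".format(val_desc), metadata=dict(help=val_desc))


@dataclass
class Demo:  # hypothetical class, for illustration only
    export_license: str = field_val_equals_help("SPDX identifier of the license")


print(Demo().export_license)  # -> <SPDX identifier of the license>
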
    export_contact_info: str = field_val_equals_help(
        "public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors"
    )
    export_license: str = field_val_equals_help(
        "SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name."
    )
    export_language: str = field_val_equals_help(
        'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".'
    )
    export_min_stt_version: str = field_val_equals_help(
        "minimum Coqui STT version (inclusive) the exported model is compatible with"
    )
    export_max_stt_version: str = field_val_equals_help(
        "maximum Coqui STT version (inclusive) the exported model is compatible with"
    )
    export_description: str = field_val_equals_help(
        "Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc."
    )

    # Reporting
    log_level: int = field(
        default=1,
        metadata=dict(
            help="log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR"
        ),
    )
    show_progressbar: bool = field(
        default=True,
        metadata=dict(
            help="Show progress for training, validation and testing processes. Log level should be > 0."
        ),
    )

    log_placement: bool = field(
        default=False,
        metadata=dict(
            help="whether to log device placement of the operators to the console"
        ),
    )
    report_count: int = field(
        default=5,
        metadata=dict(
            help="number of phrases for each of best WER, median WER and worst WER to print out during a WER report"
        ),
    )

    summary_dir: str = field(
        default="",
        metadata=dict(
            help='target directory for TensorBoard summaries - defaults to directory "stt/summaries" within user\'s data home specified by the XDG Base Directory Specification'
        ),
    )

    test_output_file: str = field(
        default="",
        metadata=dict(
            help="path to a file to save all src/decoded/distance/loss tuples generated during a test epoch"
        ),
    )

    # Geometry
    n_hidden: int = field(
        default=2048, metadata=dict(help="layer width to use when initialising layers")
    )
    layer_norm: bool = field(
        default=False,
        metadata=dict(
            help="whether to use layer-normalization after each fully-connected layer (except the last one)"
        ),
    )

    # Initialization
    random_seed: int = field(
        default=4568,
        metadata=dict(help="default random seed that is used to initialize variables"),
    )

    # Early Stopping
    early_stop: bool = field(
        default=False,
        metadata=dict(
            help="Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled."
        ),
    )
    es_epochs: int = field(
        default=25,
        metadata=dict(
            help="Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint, so when a checkpoint is resumed the loss calculation starts over from that point"
        ),
    )
    es_min_delta: float = field(
        default=0.05,
        metadata=dict(
            help="Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau"
        ),
    )

    # Reduce learning rate on plateau
    reduce_lr_on_plateau: bool = field(
        default=False,
        metadata=dict(
            help="Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs."
        ),
    )
    plateau_epochs: int = field(
        default=10,
        metadata=dict(
            help="Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping"
        ),
    )
    plateau_reduction: float = field(
        default=0.1,
        metadata=dict(
            help="Multiplicative factor to apply to the current learning rate if a plateau has occurred."
        ),
    )
    force_initialize_learning_rate: bool = field(
        default=False,
        metadata=dict(
            help="Force re-initialization of learning rate which was previously reduced."
        ),
    )

    # Decoder
    bytes_output_mode: bool = field(
        default=False,
        metadata=dict(
            help="enable Bytes Output Mode. When this is used the model outputs UTF-8 byte values directly rather than using an alphabet mapping. The --alphabet_config_path option will be ignored. See the training documentation for more details."
        ),
    )
    alphabet_config_path: str = field(
        default="data/alphabet.txt",
        metadata=dict(
            help="path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format."
        ),
    )
    scorer_path: str = field(
        default="", metadata=dict(help="path to the external scorer file.")
    )
    beam_width: int = field(
        default=1024,
        metadata=dict(
            help="beam width used in the CTC decoder when building candidate transcriptions"
        ),
    )
    # TODO move these defaults into some sort of external (inheritable?) configuration
    lm_alpha: float = field(
        default=0.931289039105002,
        metadata=dict(
            help="the alpha hyperparameter of the CTC decoder. Language Model weight."
        ),
    )
    lm_beta: float = field(
        default=1.1834137581510284,
        metadata=dict(
            help="the beta hyperparameter of the CTC decoder. Word insertion weight."
        ),
    )
    cutoff_prob: float = field(
        default=1.0,
        metadata=dict(
            help="only consider characters until this probability mass is reached. 1.0 = disabled."
        ),
    )
    cutoff_top_n: int = field(
        default=300,
        metadata=dict(
            help="only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled."
        ),
    )

    # Inference mode
    one_shot_infer: str = field(
        default=None,
        metadata=dict(
            help="one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it."
        ),
    )

    # Optimizer mode
    lm_alpha_max: int = field(
        default=5,
        metadata=dict(
            help="the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight."
        ),
    )
    lm_beta_max: int = field(
        default=5,
        metadata=dict(
            help="the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight."
        ),
    )
    n_trials: int = field(
        default=2400,
        metadata=dict(
            help="the number of trials to run during hyperparameter optimization."
        ),
    )

    def check_values(self):
        c = asdict(self)
        check_argument("alphabet_config_path", c, is_path=True)
        check_argument("one_shot_infer", c, is_path=True)


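Taken together, the dataclass above replaces absl flag definitions with Coqpit fields: each field's type and default define the flag, and the metadata help string becomes the --help text. A minimal sketch of the pattern (MiniConfig is a hypothetical stand-in; parse_args(arg_prefix="") mirrors the call used below):

from dataclasses import dataclass, field

from coqpit import Coqpit


@dataclass
class MiniConfig(Coqpit):  # hypothetical stand-in for _SttConfig
    n_hidden: int = field(default=2048, metadata=dict(help="layer width"))
    early_stop: bool = field(default=False, metadata=dict(help="enable early stopping"))


config = MiniConfig()
# Picks up e.g. `--n_hidden 100 --early_stop false` from sys.argv:
config.parse_args(arg_prefix="")
print(config.n_hidden, config.early_stop)
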
def initialize_globals():
    c = AttrDict()
    c = _SttConfig()
    c.parse_args(arg_prefix="")

    # Augmentations
    c.augmentations = parse_augmentations(FLAGS.augment)
    if c.augmentations and FLAGS.feature_cache and FLAGS.cache_for_epochs == 0:
    c.augmentations = parse_augmentations(c.augment)
    if c.augmentations and c.feature_cache and c.cache_for_epochs == 0:
        log_warn(
            "Due to current feature-cache settings the exact same sample augmentations of the first "
            "epoch will be repeated on all following epochs. This could lead to unintended over-fitting. "
            "You could use --cache_for_epochs <n_epochs> to invalidate the cache after a given number of epochs."
        )

    if FLAGS.normalize_sample_rate:
        c.augmentations = [NormalizeSampleRate(FLAGS.audio_sample_rate)] + c[
    if c.normalize_sample_rate:
        c.augmentations = [NormalizeSampleRate(c.audio_sample_rate)] + c[
            "augmentations"
        ]

    # Caching
    if FLAGS.cache_for_epochs == 1:
    if c.cache_for_epochs == 1:
        log_warn(
            "--cache_for_epochs == 1 is (re-)creating the feature cache on every epoch but will never use it."
        )

    # Read-buffer
    FLAGS.read_buffer = parse_file_size(FLAGS.read_buffer)
    c.read_buffer = parse_file_size(c.read_buffer)

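parse_file_size itself is not shown in this diff; per the read_buffer help text it turns strings like "1MB" into a byte count. A rough stand-in sketch, assuming binary (1024-based) units:

def parse_file_size_sketch(value):  # hypothetical stand-in for parse_file_size
    value = str(value).strip().upper()
    units = {"KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3, "TB": 1024 ** 4}
    for suffix, factor in units.items():
        if value.endswith(suffix):
            return int(float(value[: -len(suffix)]) * factor)
    return int(value)  # plain byte count

assert parse_file_size_sketch("1MB") == 1048576
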
    # Set default dropout rates
    if FLAGS.dropout_rate2 < 0:
        FLAGS.dropout_rate2 = FLAGS.dropout_rate
    if FLAGS.dropout_rate3 < 0:
        FLAGS.dropout_rate3 = FLAGS.dropout_rate
    if FLAGS.dropout_rate6 < 0:
        FLAGS.dropout_rate6 = FLAGS.dropout_rate
    if c.dropout_rate2 < 0:
        c.dropout_rate2 = c.dropout_rate
    if c.dropout_rate3 < 0:
        c.dropout_rate3 = c.dropout_rate
    if c.dropout_rate6 < 0:
        c.dropout_rate6 = c.dropout_rate

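The effect of the defaulting above: any per-layer rate left at its -1.0 sentinel inherits the global dropout_rate. Illustrated with a stand-in object (editor's sketch, not part of this commit):

from types import SimpleNamespace

# Stand-in config where only the global rate was set on the command line:
c = SimpleNamespace(dropout_rate=0.05, dropout_rate2=-1.0,
                    dropout_rate3=-1.0, dropout_rate6=-1.0)

for name in ("dropout_rate2", "dropout_rate3", "dropout_rate6"):
    if getattr(c, name) < 0:
        setattr(c, name, c.dropout_rate)

assert c.dropout_rate2 == c.dropout_rate3 == c.dropout_rate6 == 0.05
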
    # Set default checkpoint dir
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))
    if not c.checkpoint_dir:
        c.checkpoint_dir = xdg.save_data_path(os.path.join("stt", "checkpoints"))

    if FLAGS.load_train not in ["last", "best", "init", "auto"]:
        FLAGS.load_train = "auto"
    if c.load_train not in ["last", "best", "init", "auto"]:
        c.load_train = "auto"

    if FLAGS.load_evaluate not in ["last", "best", "auto"]:
        FLAGS.load_evaluate = "auto"
    if c.load_evaluate not in ["last", "best", "auto"]:
        c.load_evaluate = "auto"

    # Set default summary dir
    if not FLAGS.summary_dir:
        FLAGS.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))
    if not c.summary_dir:
        c.summary_dir = xdg.save_data_path(os.path.join("stt", "summaries"))

    # Standard session configuration that'll be used for all new sessions.
    c.session_config = tfv1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_placement,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
        gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth),
        log_device_placement=c.log_placement,
        inter_op_parallelism_threads=c.inter_op_parallelism_threads,
        intra_op_parallelism_threads=c.intra_op_parallelism_threads,
        gpu_options=tfv1.GPUOptions(allow_growth=c.use_allow_growth),
    )

    # CPU device
@@ -100,10 +615,10 @@ def initialize_globals():
    if not c.available_devices:
        c.available_devices = [c.cpu_device]

    if FLAGS.bytes_output_mode:
    if c.bytes_output_mode:
        c.alphabet = UTF8Alphabet()
    else:
        c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
        c.alphabet = Alphabet(os.path.abspath(c.alphabet_config_path))

    # Geometric Constants
    # ===================
@@ -118,7 +633,7 @@ def initialize_globals():
    c.n_context = 9  # TODO: Determine the optimal value using a validation data set

    # Number of units in hidden layers
    c.n_hidden = FLAGS.n_hidden
    c.n_hidden = c.n_hidden

    c.n_hidden_1 = c.n_hidden

@@ -136,43 +651,37 @@ def initialize_globals():
    c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label

    # Size of audio window in samples
    if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
    if (c.feature_win_len * c.audio_sample_rate) % 1000 != 0:
        log_error(
            "--feature_win_len value ({}) in milliseconds ({}) multiplied "
            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
            "your --feature_win_len value or resample your audio accordingly."
            "".format(
                FLAGS.feature_win_len,
                FLAGS.feature_win_len / 1000,
                FLAGS.audio_sample_rate,
            )
            "".format(c.feature_win_len, c.feature_win_len / 1000, c.audio_sample_rate)
        )
        sys.exit(1)

    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)
    c.audio_window_samples = c.audio_sample_rate * (c.feature_win_len / 1000)

    # Stride for feature computations in samples
    if (FLAGS.feature_win_step * FLAGS.audio_sample_rate) % 1000 != 0:
    if (c.feature_win_step * c.audio_sample_rate) % 1000 != 0:
        log_error(
            "--feature_win_step value ({}) in milliseconds ({}) multiplied "
            "by --audio_sample_rate value ({}) must be an integer value. Adjust "
            "your --feature_win_step value or resample your audio accordingly."
            "".format(
                FLAGS.feature_win_step,
                FLAGS.feature_win_step / 1000,
                FLAGS.audio_sample_rate,
                c.feature_win_step, c.feature_win_step / 1000, c.audio_sample_rate
            )
        )
        sys.exit(1)

    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
    c.audio_step_samples = c.audio_sample_rate * (c.feature_win_step / 1000)
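With the defaults defined above (feature_win_len=32, feature_win_step=20, audio_sample_rate=16000), both checks pass and the sample counts come out whole:

# Quick sanity check with the defaults from this config (editor's sketch):
audio_sample_rate = 16000
feature_win_len = 32   # ms
feature_win_step = 20  # ms

assert (feature_win_len * audio_sample_rate) % 1000 == 0   # 512000 % 1000
assert (feature_win_step * audio_sample_rate) % 1000 == 0  # 320000 % 1000

audio_window_samples = audio_sample_rate * (feature_win_len / 1000)  # 512.0
audio_step_samples = audio_sample_rate * (feature_win_step / 1000)   # 320.0
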
    if FLAGS.one_shot_infer:
        if not path_exists_remote(FLAGS.one_shot_infer):
    if c.one_shot_infer:
        if not path_exists_remote(c.one_shot_infer):
            log_error("Path specified in --one_shot_infer is not a valid file.")
            sys.exit(1)

    if FLAGS.train_cudnn and FLAGS.load_cudnn:
    if c.train_cudnn and c.load_cudnn:
        log_error(
            "Trying to use --train_cudnn, but --load_cudnn "
            "was also specified. The --load_cudnn flag is only "
@@ -185,10 +694,54 @@ def initialize_globals():

    # If separate save and load flags were not specified, default to load and save
    # from the same dir.
    if not FLAGS.save_checkpoint_dir:
        FLAGS.save_checkpoint_dir = FLAGS.checkpoint_dir
    if not c.save_checkpoint_dir:
        c.save_checkpoint_dir = c.checkpoint_dir

    if not FLAGS.load_checkpoint_dir:
        FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir
    if not c.load_checkpoint_dir:
        c.load_checkpoint_dir = c.checkpoint_dir

    ConfigSingleton._config = c  # pylint: disable=protected-access
    _ConfigSingleton._config = c  # pylint: disable=protected-access

# Logging functions
# =================


def prefix_print(prefix, message):
    print(prefix + ("\n" + prefix).join(message.split("\n")))


def log_debug(message):
    if Config.log_level == 0:
        prefix_print("D ", message)


def log_info(message):
    if Config.log_level <= 1:
        prefix_print("I ", message)


def log_warn(message):
    if Config.log_level <= 2:
        prefix_print("W ", message)


def log_error(message):
    if Config.log_level <= 3:
        prefix_print("E ", message)


def create_progressbar(*args, **kwargs):
    # Progress bars in stdout by default
    if "fd" not in kwargs:
        kwargs["fd"] = sys.stdout

    if Config.show_progressbar:
        return progressbar.ProgressBar(*args, **kwargs)

    return progressbar.NullBar(*args, **kwargs)


def log_progress(message):
    if not Config.show_progressbar:
        log_info(message)
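The Config object these helpers read from is the singleton populated via _ConfigSingleton._config above. Its implementation is not part of this diff, but a proxy of roughly this shape would produce the attribute access seen here (an assumption, for illustration only):

class _ConfigSingleton:  # sketch; the real class is not shown in this diff
    _config = None

    def __getattr__(self, name):
        if _ConfigSingleton._config is None:
            raise RuntimeError("Global configuration not yet initialized.")
        return getattr(_ConfigSingleton._config, name)


Config = _ConfigSingleton()  # Config.log_level -> _config.log_level
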
@@ -8,7 +8,7 @@ from multiprocessing.dummy import Pool
import numpy as np
from attrdict import AttrDict

from .flags import FLAGS
from .config import Config
from .io import open_remote
from .text import levenshtein

@@ -75,7 +75,7 @@ def calculate_and_print_report(wav_filenames, labels, decodings, losses, dataset
    samples.sort(key=lambda s: s.loss, reverse=True)

    # Then order by ascending WER/CER
    if FLAGS.bytes_output_mode:
    if Config.bytes_output_mode:
        samples.sort(key=lambda s: s.cer)
    else:
        samples.sort(key=lambda s: s.wer)
@@ -96,11 +96,11 @@ def print_report(samples, losses, wer, cer, dataset_name):
    )
    print("-" * 80)

    best_samples = samples[: FLAGS.report_count]
    worst_samples = samples[-FLAGS.report_count :]
    best_samples = samples[: Config.report_count]
    worst_samples = samples[-Config.report_count :]
    median_index = int(len(samples) / 2)
    median_left = int(FLAGS.report_count / 2)
    median_right = FLAGS.report_count - median_left
    median_left = int(Config.report_count / 2)
    median_right = Config.report_count - median_left
    median_samples = samples[median_index - median_left : median_index + median_right]

    def print_single_sample(sample):
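With the default report_count of 5, the slicing above takes the five best, the five worst, and a five-wide window around the median (two to the left, three to the right). A worked example:

# Editor's example of the slicing above with Config.report_count == 5.
samples = list(range(20))  # stand-in list, already sorted
report_count = 5

best_samples = samples[:report_count]               # [0, 1, 2, 3, 4]
worst_samples = samples[-report_count:]             # [15, 16, 17, 18, 19]
median_index = int(len(samples) / 2)                # 10
median_left = int(report_count / 2)                 # 2
median_right = report_count - median_left           # 3
median_samples = samples[median_index - median_left : median_index + median_right]
# -> [8, 9, 10, 11, 12]
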
@@ -12,7 +12,6 @@ from tensorflow.python.ops import gen_audio_ops as contrib_audio
from .audio import DEFAULT_FORMAT, pcm_to_np, read_frames_from_file, vad_split
from .augmentations import apply_graph_augmentations, apply_sample_augmentations
from .config import Config
from .flags import FLAGS
from .helpers import MEGABYTE, remember_exception
from .sample_collections import samples_from_sources
from .text import text_to_char_array
@@ -31,14 +30,14 @@ def audio_to_features(
    # We need the lambdas to make TensorFlow happy.
    # pylint: disable=unnecessary-lambda
    tf.cond(
        tf.math.not_equal(sample_rate, FLAGS.audio_sample_rate),
        tf.math.not_equal(sample_rate, Config.audio_sample_rate),
        lambda: tf.print(
            "WARNING: sample rate of sample",
            sample_id,
            "(",
            sample_rate,
            ") "
            "does not match FLAGS.audio_sample_rate. This can lead to incorrect results.",
            "does not match Config.audio_sample_rate. This can lead to incorrect results.",
        ),
        lambda: tf.no_op(),
        name="matching_sample_rate",
@@ -69,7 +68,7 @@ def audio_to_features(
        spectrogram=spectrogram,
        sample_rate=sample_rate,
        dct_coefficient_count=Config.n_input,
        upper_frequency_limit=FLAGS.audio_sample_rate / 2,
        upper_frequency_limit=Config.audio_sample_rate / 2,
    )
    features = tf.reshape(features, [-1, Config.n_input])
@@ -1,460 +0,0 @@
from __future__ import absolute_import, division, print_function

import os

import absl.flags

FLAGS = absl.flags.FLAGS

# sphinx-doc: training_ref_flags_start
def create_flags():
    # Importer
    # ========

    f = absl.flags

    f.DEFINE_string(
        "train_files",
        "",
        "comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.",
    )
    f.DEFINE_string(
        "dev_files",
        "",
        "comma separated list of files specifying the datasets used for validation. Multiple files will get reported separately. If empty, validation will not be run.",
    )
    f.DEFINE_string(
        "test_files",
        "",
        "comma separated list of files specifying the datasets used for testing. Multiple files will get reported separately. If empty, the model will not be tested.",
    )
    f.DEFINE_string(
        "metrics_files",
        "",
        "comma separated list of files specifying the datasets used for tracking of metrics (after validation step). Currently the only metric is the CTC loss but without affecting the tracking of best validation loss. Multiple files will get reported separately. If empty, metrics will not be computed.",
    )

    f.DEFINE_string(
        "read_buffer",
        "1MB",
        "buffer-size for reading samples from datasets (supports file-size suffixes KB, MB, GB, TB)",
    )
    f.DEFINE_string(
        "feature_cache",
        "",
        "cache MFCC features to disk to speed up future training runs on the same data. This flag specifies the path where cached features extracted from --train_files will be saved. If empty, or if online augmentation flags are enabled, caching will be disabled.",
    )
    f.DEFINE_integer(
        "cache_for_epochs",
        0,
        'after how many epochs the feature cache is invalidated again - 0 for "never"',
    )

    f.DEFINE_integer(
        "feature_win_len", 32, "feature extraction audio window length in milliseconds"
    )
    f.DEFINE_integer(
        "feature_win_step", 20, "feature extraction window step length in milliseconds"
    )
    f.DEFINE_integer("audio_sample_rate", 16000, "sample rate value expected by model")
    f.DEFINE_boolean(
        "normalize_sample_rate",
        True,
        "normalize sample rate of all train_files to --audio_sample_rate",
    )

    # Data Augmentation
    # ================

    f.DEFINE_multi_string(
        "augment",
        None,
        'specifies an augmentation of the training samples. Format is "--augment operation[param1=value1, ...]"',
    )

    # Global Constants
    # ================

    f.DEFINE_integer(
        "epochs",
        75,
        "how many epochs (complete runs through the train files) to train for",
    )

    f.DEFINE_float("dropout_rate", 0.05, "dropout rate for feedforward layers")
    f.DEFINE_float(
        "dropout_rate2", -1.0, "dropout rate for layer 2 - defaults to dropout_rate"
    )
    f.DEFINE_float(
        "dropout_rate3", -1.0, "dropout rate for layer 3 - defaults to dropout_rate"
    )
    f.DEFINE_float("dropout_rate4", 0.0, "dropout rate for layer 4 - defaults to 0.0")
    f.DEFINE_float("dropout_rate5", 0.0, "dropout rate for layer 5 - defaults to 0.0")
    f.DEFINE_float(
        "dropout_rate6", -1.0, "dropout rate for layer 6 - defaults to dropout_rate"
    )

    f.DEFINE_float("relu_clip", 20.0, "ReLU clipping value for non-recurrent layers")

    # Adam optimizer (http://arxiv.org/abs/1412.6980) parameters

    f.DEFINE_float("beta1", 0.9, "beta 1 parameter of Adam optimizer")
    f.DEFINE_float("beta2", 0.999, "beta 2 parameter of Adam optimizer")
    f.DEFINE_float("epsilon", 1e-8, "epsilon parameter of Adam optimizer")
    f.DEFINE_float("learning_rate", 0.001, "learning rate of Adam optimizer")

    # Batch sizes

    f.DEFINE_integer("train_batch_size", 1, "number of elements in a training batch")
    f.DEFINE_integer("dev_batch_size", 1, "number of elements in a validation batch")
    f.DEFINE_integer("test_batch_size", 1, "number of elements in a test batch")

    f.DEFINE_integer(
        "export_batch_size", 1, "number of elements per batch on the exported graph"
    )

    # Performance

    f.DEFINE_integer(
        "inter_op_parallelism_threads",
        0,
        "number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED",
    )
    f.DEFINE_integer(
        "intra_op_parallelism_threads",
        0,
        "number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED",
    )
    f.DEFINE_boolean(
        "use_allow_growth",
        False,
        "use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory",
    )
    f.DEFINE_boolean(
        "load_cudnn",
        False,
        "Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.",
    )
    f.DEFINE_boolean(
        "train_cudnn",
        False,
        "use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work",
    )
    f.DEFINE_boolean(
        "automatic_mixed_precision",
        False,
        "whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.",
    )

    # Sample limits

    f.DEFINE_integer(
        "limit_train",
        0,
        "maximum number of elements to use from train set - 0 means no limit",
    )
    f.DEFINE_integer(
        "limit_dev",
        0,
        "maximum number of elements to use from validation set - 0 means no limit",
    )
    f.DEFINE_integer(
        "limit_test",
        0,
        "maximum number of elements to use from test set - 0 means no limit",
    )

    # Sample order

    f.DEFINE_boolean(
        "reverse_train", False, "if to reverse sample order of the train set"
    )
    f.DEFINE_boolean("reverse_dev", False, "if to reverse sample order of the dev set")
    f.DEFINE_boolean(
        "reverse_test", False, "if to reverse sample order of the test set"
    )

    # Checkpointing

    f.DEFINE_string(
        "checkpoint_dir",
        "",
        'directory from which checkpoints are loaded and to which they are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification',
    )
    f.DEFINE_string(
        "load_checkpoint_dir",
        "",
        'directory in which checkpoints are stored - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification',
    )
    f.DEFINE_string(
        "save_checkpoint_dir",
        "",
        'directory to which checkpoints are saved - defaults to directory "stt/checkpoints" within user\'s data home specified by the XDG Base Directory Specification',
    )
    f.DEFINE_integer("checkpoint_secs", 600, "checkpoint saving interval in seconds")
    f.DEFINE_integer(
        "max_to_keep", 5, "number of checkpoint files to keep - default value is 5"
    )
    f.DEFINE_string(
        "load_train",
        "auto",
        'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.',
    )
    f.DEFINE_string(
        "load_evaluate",
        "auto",
        'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.',
    )

    # Transfer Learning

    f.DEFINE_integer(
        "drop_source_layers",
        0,
        "single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)",
    )

    # Exporting

    f.DEFINE_string(
        "export_dir",
        "",
        "directory in which exported models are stored - if omitted, the model won't get exported",
    )
    f.DEFINE_boolean("remove_export", False, "whether to remove old exported models")
    f.DEFINE_boolean("export_tflite", False, "export a graph ready for TF Lite engine")
    f.DEFINE_integer(
        "n_steps",
        16,
        "how many timesteps to process at once by the export graph, higher values mean more latency",
    )
    f.DEFINE_boolean(
        "export_zip", False, "export a TFLite model and package with LM and info.json"
    )
    f.DEFINE_string(
        "export_file_name", "output_graph", "name for the exported model file name"
    )
    f.DEFINE_integer(
        "export_beam_width", 500, "default beam width to embed into exported graph"
    )

    # Model metadata

    f.DEFINE_string(
        "export_author_id",
        "author",
        "author of the exported model. GitHub user or organization name used to uniquely identify the author of this model",
    )
    f.DEFINE_string(
        "export_model_name",
        "model",
        "name of the exported model. Must not contain forward slashes.",
    )
    f.DEFINE_string(
        "export_model_version",
        "0.0.1",
        "semantic version of the exported model. See https://semver.org/. This is fully controlled by you as author of the model and has no required connection with Coqui STT versions",
    )

    def str_val_equals_help(name, val_desc):
        f.DEFINE_string(name, "<{}>".format(val_desc), val_desc)

    str_val_equals_help(
        "export_contact_info",
        "public contact information of the author. Can be an email address, or a link to a contact form, issue tracker, or discussion forum. Must provide a way to reach the model authors",
    )
    str_val_equals_help(
        "export_license",
        "SPDX identifier of the license of the exported model. See https://spdx.org/licenses/. If the license does not have an SPDX identifier, use the license name.",
    )
    str_val_equals_help(
        "export_language",
        'language the model was trained on - IETF BCP 47 language tag including at least language, script and region subtags. E.g. "en-Latn-UK" or "de-Latn-DE" or "cmn-Hans-CN". Include as much info as you can without loss of precision. For example, if a model is trained on Scottish English, include the variant subtag: "en-Latn-GB-Scotland".',
    )
    str_val_equals_help(
        "export_min_stt_version",
        "minimum Coqui STT version (inclusive) the exported model is compatible with",
    )
    str_val_equals_help(
        "export_max_stt_version",
        "maximum Coqui STT version (inclusive) the exported model is compatible with",
    )
    str_val_equals_help(
        "export_description",
        "Freeform description of the model being exported. Markdown accepted. You can also leave this flag unchanged and edit the generated .md file directly. Useful things to describe are demographic and acoustic characteristics of the data used to train the model, any architectural changes, names of public datasets that were used when applicable, hyperparameters used for training, evaluation results on standard benchmark datasets, etc.",
    )

    # Reporting

    f.DEFINE_integer(
        "log_level",
        1,
        "log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR",
    )
    f.DEFINE_boolean(
        "show_progressbar",
        True,
        "Show progress for training, validation and testing processes. Log level should be > 0.",
    )

    f.DEFINE_boolean(
        "log_placement",
        False,
        "whether to log device placement of the operators to the console",
    )
    f.DEFINE_integer(
        "report_count",
        5,
        "number of phrases for each of best WER, median WER and worst WER to print out during a WER report",
    )

    f.DEFINE_string(
        "summary_dir",
        "",
        'target directory for TensorBoard summaries - defaults to directory "stt/summaries" within user\'s data home specified by the XDG Base Directory Specification',
    )

    f.DEFINE_string(
        "test_output_file",
        "",
        "path to a file to save all src/decoded/distance/loss tuples generated during a test epoch",
    )

    # Geometry

    f.DEFINE_integer("n_hidden", 2048, "layer width to use when initialising layers")
    f.DEFINE_boolean(
        "layer_norm",
        False,
        "wether to use layer-normalization after each fully-connected layer (except the last one)",
    )

    # Initialization

    f.DEFINE_integer(
        "random_seed", 4568, "default random seed that is used to initialize variables"
    )

    # Early Stopping

    f.DEFINE_boolean(
        "early_stop",
        False,
        "Enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.",
    )
    f.DEFINE_integer(
        "es_epochs",
        25,
        "Number of epochs with no improvement after which training will be stopped. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point",
    )
    f.DEFINE_float(
        "es_min_delta",
        0.05,
        "Minimum change in loss to qualify as an improvement. This value will also be used in Reduce learning rate on plateau",
    )

    # Reduce learning rate on plateau

    f.DEFINE_boolean(
        "reduce_lr_on_plateau",
        False,
        "Enable reducing the learning rate if a plateau is reached. This is the case if the validation loss did not improve for some epochs.",
    )
    f.DEFINE_integer(
        "plateau_epochs",
        10,
        "Number of epochs to consider for RLROP. Has to be smaller than es_epochs from early stopping",
    )
    f.DEFINE_float(
        "plateau_reduction",
        0.1,
        "Multiplicative factor to apply to the current learning rate if a plateau has occurred.",
    )
    f.DEFINE_boolean(
        "force_initialize_learning_rate",
        False,
        "Force re-initialization of learning rate which was previously reduced.",
    )

    # Decoder

    f.DEFINE_boolean(
        "bytes_output_mode",
        False,
        "enable Bytes Output Mode mode. When this is used the model outputs UTF-8 byte values directly rather than using an alphabet mapping. The --alphabet_config_path option will be ignored. See the training documentation for more details.",
    )
    f.DEFINE_string(
        "alphabet_config_path",
        "data/alphabet.txt",
        "path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.",
    )
    f.DEFINE_string("scorer_path", "", "path to the external scorer file.")
    f.DEFINE_alias("scorer", "scorer_path")
    f.DEFINE_integer(
        "beam_width",
        1024,
        "beam width used in the CTC decoder when building candidate transcriptions",
    )
    f.DEFINE_float(
        "lm_alpha",
        0.931289039105002,
        "the alpha hyperparameter of the CTC decoder. Language Model weight.",
    )
    f.DEFINE_float(
        "lm_beta",
        1.1834137581510284,
        "the beta hyperparameter of the CTC decoder. Word insertion weight.",
    )
    f.DEFINE_float(
        "cutoff_prob",
        1.0,
        "only consider characters until this probability mass is reached. 1.0 = disabled.",
    )
    f.DEFINE_integer(
        "cutoff_top_n",
        300,
        "only process this number of characters sorted by probability mass for each time step. If bigger than alphabet size, disabled.",
    )

    # Inference mode

    f.DEFINE_string(
        "one_shot_infer",
        "",
        "one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.",
    )

    # Optimizer mode

    f.DEFINE_float(
        "lm_alpha_max",
        5,
        "the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.",
    )
    f.DEFINE_float(
        "lm_beta_max",
        5,
        "the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.",
    )
    f.DEFINE_integer(
        "n_trials",
        2400,
        "the number of trials to run during hyperparameter optimization.",
    )

    # Register validators for paths which require a file to be specified

    f.register_validator(
        "alphabet_config_path",
        os.path.isfile,
        message="The file pointed to by --alphabet_config_path must exist and be readable.",
    )

    f.register_validator(
        "one_shot_infer",
        lambda value: not value or os.path.isfile(value),
        message="The file pointed to by --one_shot_infer must exist and be readable.",
    )


# sphinx-doc: training_ref_flags_end
@@ -1,50 +0,0 @@
from __future__ import print_function

import sys

import progressbar

from .flags import FLAGS

# Logging functions
# =================


def prefix_print(prefix, message):
    print(prefix + ("\n" + prefix).join(message.split("\n")))


def log_debug(message):
    if FLAGS.log_level == 0:
        prefix_print("D ", message)


def log_info(message):
    if FLAGS.log_level <= 1:
        prefix_print("I ", message)


def log_warn(message):
    if FLAGS.log_level <= 2:
        prefix_print("W ", message)


def log_error(message):
    if FLAGS.log_level <= 3:
        prefix_print("E ", message)


def create_progressbar(*args, **kwargs):
    # Progress bars in stdout by default
    if "fd" not in kwargs:
        kwargs["fd"] = sys.stdout

    if FLAGS.show_progressbar:
        return progressbar.ProgressBar(*args, **kwargs)

    return progressbar.NullBar(*args, **kwargs)


def log_progress(message):
    if not FLAGS.show_progressbar:
        log_info(message)