Merge pull request #1942 from coqui-ai/nc-api-boundary
Python training API cleanup, mark nodes known by native client
This commit is contained in:
commit
71da178138
|
@ -2,8 +2,7 @@
|
|||
import os
|
||||
from import_ldc93s1 import _download_and_preprocess_data as download_ldc
|
||||
from coqui_stt_training.util.config import initialize_globals_from_args
|
||||
from coqui_stt_training.train import train, test, early_training_checks
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from coqui_stt_training.train import train, test
|
||||
|
||||
# only one GPU for only one training sample
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
@ -21,8 +20,6 @@ initialize_globals_from_args(
|
|||
epochs=200,
|
||||
)
|
||||
|
||||
early_training_checks()
|
||||
|
||||
train()
|
||||
tfv1.reset_default_graph()
|
||||
|
||||
test()
|
||||
|
|
|
@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function
|
|||
|
||||
import sys
|
||||
|
||||
import absl.app
|
||||
import optuna
|
||||
import tensorflow.compat.v1 as tfv1
|
||||
from coqui_stt_ctcdecoder import Scorer
|
||||
from coqui_stt_training.evaluate import evaluate
|
||||
from coqui_stt_training.train import create_model
|
||||
from coqui_stt_training.util.config import Config, initialize_globals_from_cli
|
||||
from coqui_stt_training.train import create_model, early_training_checks
|
||||
from coqui_stt_training.util.config import (
|
||||
Config,
|
||||
initialize_globals_from_cli,
|
||||
log_error,
|
||||
)
|
||||
from coqui_stt_training.util.evaluate_tools import wer_cer_batch
|
||||
from coqui_stt_training.util.flags import FLAGS, create_flags
|
||||
from coqui_stt_training.util.logging import log_error
|
||||
|
||||
|
||||
def character_based():
|
||||
is_character_based = False
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(
|
||||
FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
|
||||
Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
|
||||
)
|
||||
is_character_based = scorer.is_utf8_mode()
|
||||
return is_character_based
|
||||
|
||||
|
||||
def objective(trial):
|
||||
FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max)
|
||||
FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max)
|
||||
Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max)
|
||||
Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max)
|
||||
|
||||
is_character_based = trial.study.user_attrs["is_character_based"]
|
||||
|
||||
samples = []
|
||||
for step, test_file in enumerate(FLAGS.test_files.split(",")):
|
||||
for step, test_file in enumerate(Config.test_files):
|
||||
tfv1.reset_default_graph()
|
||||
|
||||
current_samples = evaluate([test_file], create_model)
|
||||
|
@ -51,10 +51,18 @@ def objective(trial):
|
|||
return cer if is_character_based else wer
|
||||
|
||||
|
||||
def main(_):
|
||||
def main():
|
||||
initialize_globals_from_cli()
|
||||
early_training_checks()
|
||||
|
||||
if not FLAGS.test_files:
|
||||
if not Config.scorer_path:
|
||||
log_error(
|
||||
"Missing --scorer_path: can't optimize scorer alpha and beta "
|
||||
"parameters without a scorer!"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not Config.test_files:
|
||||
log_error(
|
||||
"You need to specify what files to use for evaluation via "
|
||||
"the --test_files flag."
|
||||
|
@ -65,7 +73,7 @@ def main(_):
|
|||
|
||||
study = optuna.create_study()
|
||||
study.set_user_attr("is_character_based", is_character_based)
|
||||
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
|
||||
study.optimize(objective, n_jobs=1, n_trials=Config.n_trials)
|
||||
print(
|
||||
"Best params: lm_alpha={} and lm_beta={} with WER={}".format(
|
||||
study.best_params["lm_alpha"],
|
||||
|
@ -76,5 +84,4 @@ def main(_):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
create_flags()
|
||||
absl.app.run(main)
|
||||
main()
|
||||
|
|
|
@ -195,13 +195,11 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from coqui_stt_training.train import train, early_training_checks\n",
|
||||
"from coqui_stt_training.train import train\n",
|
||||
"\n",
|
||||
"# use maximum one GPU\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||
"\n",
|
||||
"early_training_checks()\n",
|
||||
"\n",
|
||||
"train()"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -197,13 +197,11 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from coqui_stt_training.train import train, early_training_checks\n",
|
||||
"from coqui_stt_training.train import train\n",
|
||||
"\n",
|
||||
"# use maximum one GPU\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||
"\n",
|
||||
"early_training_checks()\n",
|
||||
"\n",
|
||||
"train()"
|
||||
]
|
||||
},
|
||||
|
|
|
@ -522,6 +522,8 @@ def log_grads_and_vars(grads_and_vars):
|
|||
|
||||
|
||||
def train():
|
||||
early_training_checks()
|
||||
|
||||
tfv1.reset_default_graph()
|
||||
tfv1.set_random_seed(Config.random_seed)
|
||||
|
||||
|
@ -910,17 +912,28 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
batch_size = batch_size if batch_size > 0 else None
|
||||
|
||||
# Create feature computation graph
|
||||
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
input_samples = tfv1.placeholder(
|
||||
tf.float32, [Config.audio_window_samples], "input_samples"
|
||||
)
|
||||
samples = tf.expand_dims(input_samples, -1)
|
||||
mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
mfccs = tf.identity(mfccs, name="mfccs")
|
||||
|
||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||
# This shape is read by the native_client in STT_CreateModel to know the
|
||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||
# there if this shape is changed.
|
||||
#
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
input_tensor = tfv1.placeholder(
|
||||
tf.float32,
|
||||
[
|
||||
|
@ -931,15 +944,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
],
|
||||
name="input_node",
|
||||
)
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")
|
||||
|
||||
if batch_size <= 0:
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = None
|
||||
else:
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
previous_state_c = tfv1.placeholder(
|
||||
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
|
||||
)
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
previous_state_h = tfv1.placeholder(
|
||||
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
|
||||
)
|
||||
|
@ -969,6 +991,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||
# by default we get 3, the middle one being batch_size which is forced to
|
||||
# one on inference graph, so remove that dimension
|
||||
#
|
||||
# native_client: this node's name and shape are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
if tflite:
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
|
@ -1043,6 +1069,9 @@ def export():
|
|||
graph_version = int(file_relative_read("GRAPH_VERSION").strip())
|
||||
assert graph_version > 0
|
||||
|
||||
# native_client: these nodes's names and shapes are part of the API boundary
|
||||
# with the native client, if you change them you should sync changes with
|
||||
# the C++ code.
|
||||
outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
|
||||
outputs["metadata_sample_rate"] = tf.constant(
|
||||
[Config.audio_sample_rate], name="metadata_sample_rate"
|
||||
|
@ -1266,7 +1295,6 @@ def early_training_checks():
|
|||
|
||||
def main():
|
||||
initialize_globals_from_cli()
|
||||
early_training_checks()
|
||||
|
||||
if Config.train_files:
|
||||
train()
|
||||
|
|
Loading…
Reference in New Issue