Merge pull request #1942 from coqui-ai/nc-api-boundary

Python training API cleanup, mark nodes known by native client
This commit is contained in:
Reuben Morais 2021-08-23 12:57:08 +02:00 committed by GitHub
commit 71da178138
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 30 deletions

View File

@ -2,8 +2,7 @@
import os import os
from import_ldc93s1 import _download_and_preprocess_data as download_ldc from import_ldc93s1 import _download_and_preprocess_data as download_ldc
from coqui_stt_training.util.config import initialize_globals_from_args from coqui_stt_training.util.config import initialize_globals_from_args
from coqui_stt_training.train import train, test, early_training_checks from coqui_stt_training.train import train, test
import tensorflow.compat.v1 as tfv1
# only one GPU for only one training sample # only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@ -21,8 +20,6 @@ initialize_globals_from_args(
epochs=200, epochs=200,
) )
early_training_checks()
train() train()
tfv1.reset_default_graph()
test() test()

View File

@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function
import sys import sys
import absl.app
import optuna import optuna
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
from coqui_stt_ctcdecoder import Scorer from coqui_stt_ctcdecoder import Scorer
from coqui_stt_training.evaluate import evaluate from coqui_stt_training.evaluate import evaluate
from coqui_stt_training.train import create_model from coqui_stt_training.train import create_model, early_training_checks
from coqui_stt_training.util.config import Config, initialize_globals_from_cli from coqui_stt_training.util.config import (
Config,
initialize_globals_from_cli,
log_error,
)
from coqui_stt_training.util.evaluate_tools import wer_cer_batch from coqui_stt_training.util.evaluate_tools import wer_cer_batch
from coqui_stt_training.util.flags import FLAGS, create_flags
from coqui_stt_training.util.logging import log_error
def character_based(): def character_based():
is_character_based = False is_character_based = False
if FLAGS.scorer_path: scorer = Scorer(
scorer = Scorer( Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet )
) is_character_based = scorer.is_utf8_mode()
is_character_based = scorer.is_utf8_mode()
return is_character_based return is_character_based
def objective(trial): def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max) Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max) Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max)
is_character_based = trial.study.user_attrs["is_character_based"] is_character_based = trial.study.user_attrs["is_character_based"]
samples = [] samples = []
for step, test_file in enumerate(FLAGS.test_files.split(",")): for step, test_file in enumerate(Config.test_files):
tfv1.reset_default_graph() tfv1.reset_default_graph()
current_samples = evaluate([test_file], create_model) current_samples = evaluate([test_file], create_model)
@ -51,10 +51,18 @@ def objective(trial):
return cer if is_character_based else wer return cer if is_character_based else wer
def main(_): def main():
initialize_globals_from_cli() initialize_globals_from_cli()
early_training_checks()
if not FLAGS.test_files: if not Config.scorer_path:
log_error(
"Missing --scorer_path: can't optimize scorer alpha and beta "
"parameters without a scorer!"
)
sys.exit(1)
if not Config.test_files:
log_error( log_error(
"You need to specify what files to use for evaluation via " "You need to specify what files to use for evaluation via "
"the --test_files flag." "the --test_files flag."
@ -65,7 +73,7 @@ def main(_):
study = optuna.create_study() study = optuna.create_study()
study.set_user_attr("is_character_based", is_character_based) study.set_user_attr("is_character_based", is_character_based)
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials) study.optimize(objective, n_jobs=1, n_trials=Config.n_trials)
print( print(
"Best params: lm_alpha={} and lm_beta={} with WER={}".format( "Best params: lm_alpha={} and lm_beta={} with WER={}".format(
study.best_params["lm_alpha"], study.best_params["lm_alpha"],
@ -76,5 +84,4 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
create_flags() main()
absl.app.run(main)

View File

@ -195,13 +195,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from coqui_stt_training.train import train, early_training_checks\n", "from coqui_stt_training.train import train\n",
"\n", "\n",
"# use maximum one GPU\n", "# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n", "\n",
"early_training_checks()\n",
"\n",
"train()" "train()"
] ]
}, },

View File

@ -197,13 +197,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from coqui_stt_training.train import train, early_training_checks\n", "from coqui_stt_training.train import train\n",
"\n", "\n",
"# use maximum one GPU\n", "# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n", "\n",
"early_training_checks()\n",
"\n",
"train()" "train()"
] ]
}, },

View File

@ -522,6 +522,8 @@ def log_grads_and_vars(grads_and_vars):
def train(): def train():
early_training_checks()
tfv1.reset_default_graph() tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed) tfv1.set_random_seed(Config.random_seed)
@ -910,17 +912,28 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
batch_size = batch_size if batch_size > 0 else None batch_size = batch_size if batch_size > 0 else None
# Create feature computation graph # Create feature computation graph
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
input_samples = tfv1.placeholder( input_samples = tfv1.placeholder(
tf.float32, [Config.audio_window_samples], "input_samples" tf.float32, [Config.audio_window_samples], "input_samples"
) )
samples = tf.expand_dims(input_samples, -1) samples = tf.expand_dims(input_samples, -1)
mfccs, _ = audio_to_features(samples, Config.audio_sample_rate) mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
mfccs = tf.identity(mfccs, name="mfccs") mfccs = tf.identity(mfccs, name="mfccs")
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input] # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
# This shape is read by the native_client in STT_CreateModel to know the # This shape is read by the native_client in STT_CreateModel to know the
# value of n_steps, n_context and n_input. Make sure you update the code # value of n_steps, n_context and n_input. Make sure you update the code
# there if this shape is changed. # there if this shape is changed.
#
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
input_tensor = tfv1.placeholder( input_tensor = tfv1.placeholder(
tf.float32, tf.float32,
[ [
@ -931,15 +944,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
], ],
name="input_node", name="input_node",
) )
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths") seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")
if batch_size <= 0: if batch_size <= 0:
# no state management since n_step is expected to be dynamic too (see below) # no state management since n_step is expected to be dynamic too (see below)
previous_state = None previous_state = None
else: else:
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
previous_state_c = tfv1.placeholder( previous_state_c = tfv1.placeholder(
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c" tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
) )
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
previous_state_h = tfv1.placeholder( previous_state_h = tfv1.placeholder(
tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h" tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
) )
@ -969,6 +991,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# TF Lite runtime will check that input dimensions are 1, 2 or 4 # TF Lite runtime will check that input dimensions are 1, 2 or 4
# by default we get 3, the middle one being batch_size which is forced to # by default we get 3, the middle one being batch_size which is forced to
# one on inference graph, so remove that dimension # one on inference graph, so remove that dimension
#
# native_client: this node's name and shape are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
if tflite: if tflite:
logits = tf.squeeze(logits, [1]) logits = tf.squeeze(logits, [1])
@ -1043,6 +1069,9 @@ def export():
graph_version = int(file_relative_read("GRAPH_VERSION").strip()) graph_version = int(file_relative_read("GRAPH_VERSION").strip())
assert graph_version > 0 assert graph_version > 0
# native_client: these nodes's names and shapes are part of the API boundary
# with the native client, if you change them you should sync changes with
# the C++ code.
outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version") outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
outputs["metadata_sample_rate"] = tf.constant( outputs["metadata_sample_rate"] = tf.constant(
[Config.audio_sample_rate], name="metadata_sample_rate" [Config.audio_sample_rate], name="metadata_sample_rate"
@ -1266,7 +1295,6 @@ def early_training_checks():
def main(): def main():
initialize_globals_from_cli() initialize_globals_from_cli()
early_training_checks()
if Config.train_files: if Config.train_files:
train() train()