Merge pull request #1942 from coqui-ai/nc-api-boundary
Python training API cleanup, mark nodes known by native client
commit 71da178138
@@ -2,8 +2,7 @@
 import os
 from import_ldc93s1 import _download_and_preprocess_data as download_ldc
 from coqui_stt_training.util.config import initialize_globals_from_args
-from coqui_stt_training.train import train, test, early_training_checks
-import tensorflow.compat.v1 as tfv1
+from coqui_stt_training.train import train, test
 
 # only one GPU for only one training sample
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -21,8 +20,6 @@ initialize_globals_from_args(
     epochs=200,
 )
 
-early_training_checks()
 
 train()
-tfv1.reset_default_graph()
 test()
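Taken together, these two hunks strip the single-sample training script down to the new API surface: `early_training_checks()` and the explicit graph reset between `train()` and `test()` move inside the training module (see the `train()` hunk further down), so the script no longer imports `tensorflow` at all. A hedged sketch of the resulting script; the data directory and every keyword argument other than `epochs` are hypothetical placeholders, not shown in this diff:

# Post-cleanup caller: no tensorflow import, no explicit checks or resets.
import os

from import_ldc93s1 import _download_and_preprocess_data as download_ldc
from coqui_stt_training.train import train, test
from coqui_stt_training.util.config import initialize_globals_from_args

# only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

data_dir = "data/ldc93s1"  # hypothetical path
download_ldc(data_dir)     # assumed to write ldc93s1.csv into data_dir

csv = os.path.join(data_dir, "ldc93s1.csv")
initialize_globals_from_args(
    train_files=[csv],  # hypothetical kwargs; only epochs=200 appears in the diff
    dev_files=[csv],
    test_files=[csv],
    epochs=200,
)

train()  # now runs early_training_checks() and the graph reset internally
test()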
@@ -4,36 +4,36 @@ from __future__ import absolute_import, print_function
 
 import sys
 
-import absl.app
 import optuna
 import tensorflow.compat.v1 as tfv1
 from coqui_stt_ctcdecoder import Scorer
 from coqui_stt_training.evaluate import evaluate
-from coqui_stt_training.train import create_model
-from coqui_stt_training.util.config import Config, initialize_globals_from_cli
+from coqui_stt_training.train import create_model, early_training_checks
+from coqui_stt_training.util.config import (
+    Config,
+    initialize_globals_from_cli,
+    log_error,
+)
 from coqui_stt_training.util.evaluate_tools import wer_cer_batch
-from coqui_stt_training.util.flags import FLAGS, create_flags
-from coqui_stt_training.util.logging import log_error
 
 
 def character_based():
     is_character_based = False
-    if FLAGS.scorer_path:
-        scorer = Scorer(
-            FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet
-        )
-        is_character_based = scorer.is_utf8_mode()
+    scorer = Scorer(
+        Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet
+    )
+    is_character_based = scorer.is_utf8_mode()
     return is_character_based
 
 
 def objective(trial):
-    FLAGS.lm_alpha = trial.suggest_uniform("lm_alpha", 0, FLAGS.lm_alpha_max)
-    FLAGS.lm_beta = trial.suggest_uniform("lm_beta", 0, FLAGS.lm_beta_max)
+    Config.lm_alpha = trial.suggest_uniform("lm_alpha", 0, Config.lm_alpha_max)
+    Config.lm_beta = trial.suggest_uniform("lm_beta", 0, Config.lm_beta_max)
 
     is_character_based = trial.study.user_attrs["is_character_based"]
 
     samples = []
-    for step, test_file in enumerate(FLAGS.test_files.split(",")):
+    for step, test_file in enumerate(Config.test_files):
         tfv1.reset_default_graph()
 
         current_samples = evaluate([test_file], create_model)
@@ -51,10 +51,18 @@ def objective(trial):
     return cer if is_character_based else wer
 
 
-def main(_):
+def main():
     initialize_globals_from_cli()
+    early_training_checks()
 
-    if not FLAGS.test_files:
+    if not Config.scorer_path:
+        log_error(
+            "Missing --scorer_path: can't optimize scorer alpha and beta "
+            "parameters without a scorer!"
+        )
+        sys.exit(1)
+
+    if not Config.test_files:
         log_error(
             "You need to specify what files to use for evaluation via "
             "the --test_files flag."
@@ -65,7 +73,7 @@ def main(_):
 
     study = optuna.create_study()
     study.set_user_attr("is_character_based", is_character_based)
-    study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
+    study.optimize(objective, n_jobs=1, n_trials=Config.n_trials)
     print(
         "Best params: lm_alpha={} and lm_beta={} with WER={}".format(
             study.best_params["lm_alpha"],
@@ -76,5 +84,4 @@ def main(_):
 
 
 if __name__ == "__main__":
-    create_flags()
-    absl.app.run(main)
+    main()
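The hunks above migrate the scorer hyperparameter search (an optuna study over lm_alpha and lm_beta) off absl flags and onto the shared Config object, and add an up-front scorer_path check. A hedged sketch of the migration pattern; the printed values are illustrative, not from the diff:

# FLAGS -> Config migration pattern applied by this diff.
#
# Before: flags were declared and parsed by absl, and multi-valued flags
# were comma-joined strings the script had to split itself:
#     create_flags()
#     absl.app.run(main)                        # main(_) received absl argv
#     for test_file in FLAGS.test_files.split(","): ...
#
# After: one call parses the CLI into Config attributes, and list-valued
# settings such as test_files arrive pre-split.
from coqui_stt_training.util.config import Config, initialize_globals_from_cli


def main():
    initialize_globals_from_cli()          # replaces create_flags()/absl.app.run()
    for test_file in Config.test_files:    # already a list; no .split(",")
        print(test_file)


if __name__ == "__main__":
    main()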
@@ -195,13 +195,11 @@
    },
    "outputs": [],
    "source": [
-    "from coqui_stt_training.train import train, early_training_checks\n",
+    "from coqui_stt_training.train import train\n",
     "\n",
     "# use maximum one GPU\n",
     "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
     "\n",
-    "early_training_checks()\n",
-    "\n",
     "train()"
    ]
   },
@@ -197,13 +197,11 @@
    },
    "outputs": [],
    "source": [
-    "from coqui_stt_training.train import train, early_training_checks\n",
+    "from coqui_stt_training.train import train\n",
     "\n",
     "# use maximum one GPU\n",
     "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
     "\n",
-    "early_training_checks()\n",
-    "\n",
     "train()"
    ]
   },
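Both notebook hunks apply the same edit to the training cell of two notebooks. The resulting cell, reconstructed from the new side of the hunks (the `import os` line is assumed to live in an earlier cell and is added here only so the sketch is self-contained):

# Updated notebook cell: early_training_checks() disappears from user code
# because train() now runs it first (see the train() hunk below).
import os

from coqui_stt_training.train import train

# use maximum one GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

train()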
@@ -522,6 +522,8 @@ def log_grads_and_vars(grads_and_vars):
 
 
 def train():
+    early_training_checks()
+
     tfv1.reset_default_graph()
     tfv1.set_random_seed(Config.random_seed)
 
@@ -910,17 +912,28 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None
 
     # Create feature computation graph
+
+    # native_client: this node's name and shape are part of the API boundary
+    # with the native client, if you change them you should sync changes with
+    # the C++ code.
     input_samples = tfv1.placeholder(
         tf.float32, [Config.audio_window_samples], "input_samples"
     )
     samples = tf.expand_dims(input_samples, -1)
     mfccs, _ = audio_to_features(samples, Config.audio_sample_rate)
+    # native_client: this node's name and shape are part of the API boundary
+    # with the native client, if you change them you should sync changes with
+    # the C++ code.
     mfccs = tf.identity(mfccs, name="mfccs")
 
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
     # This shape is read by the native_client in STT_CreateModel to know the
     # value of n_steps, n_context and n_input. Make sure you update the code
     # there if this shape is changed.
+    #
+    # native_client: this node's name and shape are part of the API boundary
+    # with the native client, if you change them you should sync changes with
+    # the C++ code.
     input_tensor = tfv1.placeholder(
         tf.float32,
         [
@@ -931,15 +944,24 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         ],
         name="input_node",
     )
+    # native_client: this node's name and shape are part of the API boundary
+    # with the native client, if you change them you should sync changes with
+    # the C++ code.
     seq_length = tfv1.placeholder(tf.int32, [batch_size], name="input_lengths")
 
     if batch_size <= 0:
         # no state management since n_step is expected to be dynamic too (see below)
         previous_state = None
     else:
+        # native_client: this node's name and shape are part of the API boundary
+        # with the native client, if you change them you should sync changes with
+        # the C++ code.
         previous_state_c = tfv1.placeholder(
             tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_c"
         )
+        # native_client: this node's name and shape are part of the API boundary
+        # with the native client, if you change them you should sync changes with
+        # the C++ code.
         previous_state_h = tfv1.placeholder(
             tf.float32, [batch_size, Config.n_cell_dim], name="previous_state_h"
         )
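The comments added in these two hunks mark tensors whose names and shapes the native client resolves by name at model-load time (per the STT_CreateModel note above). As a hedged illustration, not part of this diff, one way to catch an accidental rename is to load an exported GraphDef and assert the boundary names are present; the export path is a hypothetical placeholder, and the exact node set depends on the graph variant (a dynamic-batch graph has no state placeholders, and the TF Lite graph differs):

# Sanity-check an exported graph for the API-boundary node names.
import tensorflow.compat.v1 as tfv1

BOUNDARY_NODES = {
    "input_samples", "mfccs",                 # feature computation
    "input_node", "input_lengths",            # acoustic model inputs
    "previous_state_c", "previous_state_h",   # LSTM state (batched graph)
}

graph_def = tfv1.GraphDef()
with open("output_graph.pb", "rb") as fin:   # hypothetical export path
    graph_def.ParseFromString(fin.read())

node_names = {node.name for node in graph_def.node}
missing = BOUNDARY_NODES - node_names
assert not missing, "API-boundary nodes missing from export: {}".format(missing)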
@@ -969,6 +991,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     # TF Lite runtime will check that input dimensions are 1, 2 or 4
     # by default we get 3, the middle one being batch_size which is forced to
     # one on inference graph, so remove that dimension
+    #
+    # native_client: this node's name and shape are part of the API boundary
+    # with the native client, if you change them you should sync changes with
+    # the C++ code.
     if tflite:
         logits = tf.squeeze(logits, [1])
 
@@ -1043,6 +1069,9 @@ def export():
     graph_version = int(file_relative_read("GRAPH_VERSION").strip())
     assert graph_version > 0
 
+    # native_client: these nodes' names and shapes are part of the API boundary
+    # with the native client, if you change them you should sync changes with
+    # the C++ code.
     outputs["metadata_version"] = tf.constant([graph_version], name="metadata_version")
     outputs["metadata_sample_rate"] = tf.constant(
         [Config.audio_sample_rate], name="metadata_sample_rate"
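The metadata constants get the same boundary marking. As a hedged, illustrative counterpart to the C++ load path, the baked-in values can be read back from an exported graph like this (the .pb path is again a hypothetical placeholder):

# Read the metadata constants that export() bakes into the graph.
import tensorflow.compat.v1 as tfv1

graph_def = tfv1.GraphDef()
with open("output_graph.pb", "rb") as fin:   # hypothetical export path
    graph_def.ParseFromString(fin.read())

graph = tfv1.Graph()
with graph.as_default():
    tfv1.import_graph_def(graph_def, name="")

with tfv1.Session(graph=graph) as session:
    version = session.run("metadata_version:0")[0]
    sample_rate = session.run("metadata_sample_rate:0")[0]
    print("graph version {}, sample rate {} Hz".format(version, sample_rate))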
|
@ -1266,7 +1295,6 @@ def early_training_checks():
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
initialize_globals_from_cli()
|
initialize_globals_from_cli()
|
||||||
early_training_checks()
|
|
||||||
|
|
||||||
if Config.train_files:
|
if Config.train_files:
|
||||||
train()
|
train()
|
||||||
|
|