Move early_training_checks to train function

This commit is contained in:
Reuben Morais 2021-08-19 18:33:32 +02:00
parent ad7335db0e
commit f90408d3ab
4 changed files with 6 additions and 12 deletions

View File

@ -2,8 +2,7 @@
import os import os
from import_ldc93s1 import _download_and_preprocess_data as download_ldc from import_ldc93s1 import _download_and_preprocess_data as download_ldc
from coqui_stt_training.util.config import initialize_globals_from_args from coqui_stt_training.util.config import initialize_globals_from_args
from coqui_stt_training.train import train, test, early_training_checks from coqui_stt_training.train import train, test
import tensorflow.compat.v1 as tfv1
# only one GPU for only one training sample # only one GPU for only one training sample
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@ -21,8 +20,6 @@ initialize_globals_from_args(
epochs=200, epochs=200,
) )
early_training_checks()
train() train()
tfv1.reset_default_graph()
test() test()

View File

@ -195,13 +195,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from coqui_stt_training.train import train, early_training_checks\n", "from coqui_stt_training.train import train\n",
"\n", "\n",
"# use maximum one GPU\n", "# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n", "\n",
"early_training_checks()\n",
"\n",
"train()" "train()"
] ]
}, },

View File

@ -197,13 +197,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from coqui_stt_training.train import train, early_training_checks\n", "from coqui_stt_training.train import train\n",
"\n", "\n",
"# use maximum one GPU\n", "# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n", "\n",
"early_training_checks()\n",
"\n",
"train()" "train()"
] ]
}, },

View File

@ -522,6 +522,8 @@ def log_grads_and_vars(grads_and_vars):
def train(): def train():
early_training_checks()
tfv1.reset_default_graph() tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed) tfv1.set_random_seed(Config.random_seed)
@ -1266,7 +1268,6 @@ def early_training_checks():
def main(): def main():
initialize_globals_from_cli() initialize_globals_from_cli()
early_training_checks()
if Config.train_files: if Config.train_files:
train() train()