diff --git a/bin/run-ldc93s1.py b/bin/run-ldc93s1.py index e266b7a4..5a2746d8 100755 --- a/bin/run-ldc93s1.py +++ b/bin/run-ldc93s1.py @@ -2,8 +2,7 @@ import os from import_ldc93s1 import _download_and_preprocess_data as download_ldc from coqui_stt_training.util.config import initialize_globals_from_args -from coqui_stt_training.train import train, test, early_training_checks -import tensorflow.compat.v1 as tfv1 +from coqui_stt_training.train import train, test # only one GPU for only one training sample os.environ["CUDA_VISIBLE_DEVICES"] = "0" @@ -21,8 +20,6 @@ initialize_globals_from_args( epochs=200, ) -early_training_checks() - train() -tfv1.reset_default_graph() + test() diff --git a/notebooks/easy-transfer-learning.ipynb b/notebooks/easy-transfer-learning.ipynb index b83f1f80..4631db82 100644 --- a/notebooks/easy-transfer-learning.ipynb +++ b/notebooks/easy-transfer-learning.ipynb @@ -195,13 +195,11 @@ }, "outputs": [], "source": [ - "from coqui_stt_training.train import train, early_training_checks\n", + "from coqui_stt_training.train import train\n", "\n", "# use maximum one GPU\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", - "early_training_checks()\n", - "\n", "train()" ] }, diff --git a/notebooks/train-your-first-coqui-STT-model.ipynb b/notebooks/train-your-first-coqui-STT-model.ipynb index 2009dfa0..bcb10d89 100644 --- a/notebooks/train-your-first-coqui-STT-model.ipynb +++ b/notebooks/train-your-first-coqui-STT-model.ipynb @@ -197,13 +197,11 @@ }, "outputs": [], "source": [ - "from coqui_stt_training.train import train, early_training_checks\n", + "from coqui_stt_training.train import train\n", "\n", "# use maximum one GPU\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", - "early_training_checks()\n", - "\n", "train()" ] }, diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 6be4ef55..3d3b7177 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -522,6 +522,8 @@ def log_grads_and_vars(grads_and_vars): def train(): + early_training_checks() + tfv1.reset_default_graph() tfv1.set_random_seed(Config.random_seed) @@ -1266,7 +1268,6 @@ def early_training_checks(): def main(): initialize_globals_from_cli() - early_training_checks() if Config.train_files: train()