Move early_training_checks to train function
This commit is contained in:
parent
ad7335db0e
commit
f90408d3ab
@ -2,8 +2,7 @@
|
|||||||
import os
|
import os
|
||||||
from import_ldc93s1 import _download_and_preprocess_data as download_ldc
|
from import_ldc93s1 import _download_and_preprocess_data as download_ldc
|
||||||
from coqui_stt_training.util.config import initialize_globals_from_args
|
from coqui_stt_training.util.config import initialize_globals_from_args
|
||||||
from coqui_stt_training.train import train, test, early_training_checks
|
from coqui_stt_training.train import train, test
|
||||||
import tensorflow.compat.v1 as tfv1
|
|
||||||
|
|
||||||
# only one GPU for only one training sample
|
# only one GPU for only one training sample
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
@ -21,8 +20,6 @@ initialize_globals_from_args(
|
|||||||
epochs=200,
|
epochs=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
early_training_checks()
|
|
||||||
|
|
||||||
train()
|
train()
|
||||||
tfv1.reset_default_graph()
|
|
||||||
test()
|
test()
|
||||||
|
@ -195,13 +195,11 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from coqui_stt_training.train import train, early_training_checks\n",
|
"from coqui_stt_training.train import train\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# use maximum one GPU\n",
|
"# use maximum one GPU\n",
|
||||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"early_training_checks()\n",
|
|
||||||
"\n",
|
|
||||||
"train()"
|
"train()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -197,13 +197,11 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from coqui_stt_training.train import train, early_training_checks\n",
|
"from coqui_stt_training.train import train\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# use maximum one GPU\n",
|
"# use maximum one GPU\n",
|
||||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"early_training_checks()\n",
|
|
||||||
"\n",
|
|
||||||
"train()"
|
"train()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -522,6 +522,8 @@ def log_grads_and_vars(grads_and_vars):
|
|||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
early_training_checks()
|
||||||
|
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
tfv1.set_random_seed(Config.random_seed)
|
tfv1.set_random_seed(Config.random_seed)
|
||||||
|
|
||||||
@ -1266,7 +1268,6 @@ def early_training_checks():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
initialize_globals_from_cli()
|
initialize_globals_from_cli()
|
||||||
early_training_checks()
|
|
||||||
|
|
||||||
if Config.train_files:
|
if Config.train_files:
|
||||||
train()
|
train()
|
||||||
|
Loading…
Reference in New Issue
Block a user