Move tfv1 calls inside high-level functions

This commit is contained in:
Josh Meyer 2021-07-30 13:25:36 -04:00
parent df26eca4d2
commit 256af35a61
3 changed files with 9 additions and 12 deletions

View File

@ -196,14 +196,12 @@
"outputs": [],
"source": [
"from coqui_stt_training.train import train, early_training_checks\n",
"import tensorflow.compat.v1 as tfv1\n",
"\n",
"# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"early_training_checks()\n",
"\n",
"tfv1.reset_default_graph()\n",
"train()"
]
},
@ -257,7 +255,6 @@
"source": [
"from coqui_stt_training.train import test\n",
"\n",
"tfv1.reset_default_graph()\n",
"test()"
]
}

View File

@ -198,14 +198,12 @@
"outputs": [],
"source": [
"from coqui_stt_training.train import train, early_training_checks\n",
"import tensorflow.compat.v1 as tfv1\n",
"\n",
"# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"early_training_checks()\n",
"\n",
"tfv1.reset_default_graph()\n",
"train()"
]
},
@ -236,7 +234,6 @@
"source": [
"from coqui_stt_training.train import test\n",
"\n",
"tfv1.reset_default_graph()\n",
"test()"
]
}

View File

@ -522,6 +522,9 @@ def log_grads_and_vars(grads_and_vars):
def train():
tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed)
exception_box = ExceptionBox()
# Create training and validation datasets
@ -896,6 +899,8 @@ def train():
def test():
tfv1.reset_default_graph()
samples = evaluate(Config.test_files, create_model)
if Config.test_output_file:
save_samples_json(samples, Config.test_output_file)
@ -1027,6 +1032,8 @@ def export():
"""
log_info("Exporting the model...")
tfv1.reset_default_graph()
inputs, outputs, _ = create_inference_graph(
batch_size=Config.export_batch_size,
n_steps=Config.n_steps,
@ -1176,6 +1183,8 @@ def package_zip():
def do_single_file_inference(input_file_path):
tfv1.reset_default_graph()
with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@ -1253,20 +1262,15 @@ def main():
early_training_checks()
if Config.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed)
train()
if Config.test_files:
tfv1.reset_default_graph()
test()
if Config.export_dir and not Config.export_zip:
tfv1.reset_default_graph()
export()
if Config.export_zip:
tfv1.reset_default_graph()
Config.export_tflite = True
if listdir_remote(Config.export_dir):
@ -1279,7 +1283,6 @@ def main():
package_zip()
if Config.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(Config.one_shot_infer)