Move tfv1 calls inside high-level functions

This commit is contained in:
Josh Meyer 2021-07-30 13:25:36 -04:00
parent df26eca4d2
commit 256af35a61
3 changed files with 9 additions and 12 deletions

View File

@ -196,14 +196,12 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from coqui_stt_training.train import train, early_training_checks\n", "from coqui_stt_training.train import train, early_training_checks\n",
"import tensorflow.compat.v1 as tfv1\n",
"\n", "\n",
"# use maximum one GPU\n", "# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n", "\n",
"early_training_checks()\n", "early_training_checks()\n",
"\n", "\n",
"tfv1.reset_default_graph()\n",
"train()" "train()"
] ]
}, },
@ -257,7 +255,6 @@
"source": [ "source": [
"from coqui_stt_training.train import test\n", "from coqui_stt_training.train import test\n",
"\n", "\n",
"tfv1.reset_default_graph()\n",
"test()" "test()"
] ]
} }

View File

@ -198,14 +198,12 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from coqui_stt_training.train import train, early_training_checks\n", "from coqui_stt_training.train import train, early_training_checks\n",
"import tensorflow.compat.v1 as tfv1\n",
"\n", "\n",
"# use maximum one GPU\n", "# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n", "\n",
"early_training_checks()\n", "early_training_checks()\n",
"\n", "\n",
"tfv1.reset_default_graph()\n",
"train()" "train()"
] ]
}, },
@ -236,7 +234,6 @@
"source": [ "source": [
"from coqui_stt_training.train import test\n", "from coqui_stt_training.train import test\n",
"\n", "\n",
"tfv1.reset_default_graph()\n",
"test()" "test()"
] ]
} }

View File

@ -522,6 +522,9 @@ def log_grads_and_vars(grads_and_vars):
def train(): def train():
tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed)
exception_box = ExceptionBox() exception_box = ExceptionBox()
# Create training and validation datasets # Create training and validation datasets
@ -896,6 +899,8 @@ def train():
def test(): def test():
tfv1.reset_default_graph()
samples = evaluate(Config.test_files, create_model) samples = evaluate(Config.test_files, create_model)
if Config.test_output_file: if Config.test_output_file:
save_samples_json(samples, Config.test_output_file) save_samples_json(samples, Config.test_output_file)
@ -1027,6 +1032,8 @@ def export():
""" """
log_info("Exporting the model...") log_info("Exporting the model...")
tfv1.reset_default_graph()
inputs, outputs, _ = create_inference_graph( inputs, outputs, _ = create_inference_graph(
batch_size=Config.export_batch_size, batch_size=Config.export_batch_size,
n_steps=Config.n_steps, n_steps=Config.n_steps,
@ -1176,6 +1183,8 @@ def package_zip():
def do_single_file_inference(input_file_path): def do_single_file_inference(input_file_path):
tfv1.reset_default_graph()
with tfv1.Session(config=Config.session_config) as session: with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
@ -1253,20 +1262,15 @@ def main():
early_training_checks() early_training_checks()
if Config.train_files: if Config.train_files:
tfv1.reset_default_graph()
tfv1.set_random_seed(Config.random_seed)
train() train()
if Config.test_files: if Config.test_files:
tfv1.reset_default_graph()
test() test()
if Config.export_dir and not Config.export_zip: if Config.export_dir and not Config.export_zip:
tfv1.reset_default_graph()
export() export()
if Config.export_zip: if Config.export_zip:
tfv1.reset_default_graph()
Config.export_tflite = True Config.export_tflite = True
if listdir_remote(Config.export_dir): if listdir_remote(Config.export_dir):
@ -1279,7 +1283,6 @@ def main():
package_zip() package_zip()
if Config.one_shot_infer: if Config.one_shot_infer:
tfv1.reset_default_graph()
do_single_file_inference(Config.one_shot_infer) do_single_file_inference(Config.one_shot_infer)