Move tfv1 calls inside high-level functions
parent df26eca4d2
commit 256af35a61
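The hunks below remove the ad-hoc tensorflow.compat.v1 graph handling from the training notebooks and from main(), and move it into train(), test(), export(), and do_single_file_inference() themselves. A minimal sketch of why resetting the default graph at the top of each entry point matters when the functions are called repeatedly from one process, such as a notebook kernel (illustrative only, not code from this commit):

import tensorflow.compat.v1 as tfv1

def entry_point():
    # Clear whatever graph a previous call (or a previous notebook cell)
    # left behind before building this call's own ops.
    tfv1.reset_default_graph()
    tfv1.constant(0.0, name="some_op")  # stand-in for real graph construction
    return len(tfv1.get_default_graph().get_operations())

print(entry_point())  # 1
print(entry_point())  # still 1 -- repeated calls do not accumulate graph state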
@@ -196,14 +196,12 @@
    "outputs": [],
    "source": [
     "from coqui_stt_training.train import train, early_training_checks\n",
-    "import tensorflow.compat.v1 as tfv1\n",
     "\n",
     "# use maximum one GPU\n",
     "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
     "\n",
     "early_training_checks()\n",
     "\n",
-    "tfv1.reset_default_graph()\n",
     "train()"
    ]
   },
@@ -257,7 +255,6 @@
    "source": [
     "from coqui_stt_training.train import test\n",
     "\n",
-    "tfv1.reset_default_graph()\n",
     "test()"
    ]
   }
@@ -198,14 +198,12 @@
    "outputs": [],
    "source": [
     "from coqui_stt_training.train import train, early_training_checks\n",
-    "import tensorflow.compat.v1 as tfv1\n",
     "\n",
     "# use maximum one GPU\n",
     "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
     "\n",
     "early_training_checks()\n",
     "\n",
-    "tfv1.reset_default_graph()\n",
     "train()"
    ]
   },
@@ -236,7 +234,6 @@
    "source": [
     "from coqui_stt_training.train import test\n",
     "\n",
-    "tfv1.reset_default_graph()\n",
     "test()"
    ]
   }
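After the notebook hunks above, the affected cells reduce to plain calls into the training package. Roughly, as reconstructed from the context lines (the os import is assumed to come from an earlier, unchanged cell):

from coqui_stt_training.train import train, early_training_checks

# use maximum one GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

early_training_checks()

train()

and, in the later test cell:

from coqui_stt_training.train import test

test()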
@@ -522,6 +522,9 @@ def log_grads_and_vars(grads_and_vars):
 
 
 def train():
+    tfv1.reset_default_graph()
+    tfv1.set_random_seed(Config.random_seed)
+
     exception_box = ExceptionBox()
 
     # Create training and validation datasets
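Note the ordering in the hunk above: set_random_seed is applied after the new reset_default_graph() call. The graph-level seed set by tfv1.set_random_seed lives on the current default graph and does not survive a reset, so it has to be set once the fresh graph exists. A quick illustration (assumes TensorFlow with the v1 compat API available; the seed value is arbitrary):

import tensorflow.compat.v1 as tfv1

tfv1.set_random_seed(1234)
print(tfv1.get_default_graph().seed)  # 1234

tfv1.reset_default_graph()
print(tfv1.get_default_graph().seed)  # None -- the seed went away with the old graph

tfv1.set_random_seed(1234)            # hence the seed is set again after the reset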
@@ -896,6 +899,8 @@ def train():
 
 
 def test():
+    tfv1.reset_default_graph()
+
     samples = evaluate(Config.test_files, create_model)
     if Config.test_output_file:
         save_samples_json(samples, Config.test_output_file)
@@ -1027,6 +1032,8 @@ def export():
     """
     log_info("Exporting the model...")
 
+    tfv1.reset_default_graph()
+
     inputs, outputs, _ = create_inference_graph(
         batch_size=Config.export_batch_size,
         n_steps=Config.n_steps,
@@ -1176,6 +1183,8 @@ def package_zip():
 
 
 def do_single_file_inference(input_file_path):
+    tfv1.reset_default_graph()
+
     with tfv1.Session(config=Config.session_config) as session:
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
 
@@ -1253,20 +1262,15 @@ def main():
     early_training_checks()
 
     if Config.train_files:
-        tfv1.reset_default_graph()
-        tfv1.set_random_seed(Config.random_seed)
         train()
 
     if Config.test_files:
-        tfv1.reset_default_graph()
         test()
 
     if Config.export_dir and not Config.export_zip:
-        tfv1.reset_default_graph()
         export()
 
     if Config.export_zip:
-        tfv1.reset_default_graph()
         Config.export_tflite = True
 
         if listdir_remote(Config.export_dir):
@@ -1279,7 +1283,6 @@ def main():
         package_zip()
 
     if Config.one_shot_infer:
-        tfv1.reset_default_graph()
         do_single_file_inference(Config.one_shot_infer)
 
 