Merge pull request #1927 from JRMeyer/tfv1-moving
Move tfv1 calls inside high-level functions
This commit is contained in:
commit
b77d33a108
|
@ -196,14 +196,12 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from coqui_stt_training.train import train, early_training_checks\n",
|
"from coqui_stt_training.train import train, early_training_checks\n",
|
||||||
"import tensorflow.compat.v1 as tfv1\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# use maximum one GPU\n",
|
"# use maximum one GPU\n",
|
||||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"early_training_checks()\n",
|
"early_training_checks()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tfv1.reset_default_graph()\n",
|
|
||||||
"train()"
|
"train()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -257,7 +255,6 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from coqui_stt_training.train import test\n",
|
"from coqui_stt_training.train import test\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tfv1.reset_default_graph()\n",
|
|
||||||
"test()"
|
"test()"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,14 +198,12 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from coqui_stt_training.train import train, early_training_checks\n",
|
"from coqui_stt_training.train import train, early_training_checks\n",
|
||||||
"import tensorflow.compat.v1 as tfv1\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# use maximum one GPU\n",
|
"# use maximum one GPU\n",
|
||||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"early_training_checks()\n",
|
"early_training_checks()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tfv1.reset_default_graph()\n",
|
|
||||||
"train()"
|
"train()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -236,7 +234,6 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from coqui_stt_training.train import test\n",
|
"from coqui_stt_training.train import test\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tfv1.reset_default_graph()\n",
|
|
||||||
"test()"
|
"test()"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -522,6 +522,9 @@ def log_grads_and_vars(grads_and_vars):
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
tfv1.reset_default_graph()
|
||||||
|
tfv1.set_random_seed(Config.random_seed)
|
||||||
|
|
||||||
exception_box = ExceptionBox()
|
exception_box = ExceptionBox()
|
||||||
|
|
||||||
# Create training and validation datasets
|
# Create training and validation datasets
|
||||||
|
@ -896,6 +899,8 @@ def train():
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
|
tfv1.reset_default_graph()
|
||||||
|
|
||||||
samples = evaluate(Config.test_files, create_model)
|
samples = evaluate(Config.test_files, create_model)
|
||||||
if Config.test_output_file:
|
if Config.test_output_file:
|
||||||
save_samples_json(samples, Config.test_output_file)
|
save_samples_json(samples, Config.test_output_file)
|
||||||
|
@ -1027,6 +1032,8 @@ def export():
|
||||||
"""
|
"""
|
||||||
log_info("Exporting the model...")
|
log_info("Exporting the model...")
|
||||||
|
|
||||||
|
tfv1.reset_default_graph()
|
||||||
|
|
||||||
inputs, outputs, _ = create_inference_graph(
|
inputs, outputs, _ = create_inference_graph(
|
||||||
batch_size=Config.export_batch_size,
|
batch_size=Config.export_batch_size,
|
||||||
n_steps=Config.n_steps,
|
n_steps=Config.n_steps,
|
||||||
|
@ -1176,6 +1183,8 @@ def package_zip():
|
||||||
|
|
||||||
|
|
||||||
def do_single_file_inference(input_file_path):
|
def do_single_file_inference(input_file_path):
|
||||||
|
tfv1.reset_default_graph()
|
||||||
|
|
||||||
with tfv1.Session(config=Config.session_config) as session:
|
with tfv1.Session(config=Config.session_config) as session:
|
||||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||||
|
|
||||||
|
@ -1253,20 +1262,15 @@ def main():
|
||||||
early_training_checks()
|
early_training_checks()
|
||||||
|
|
||||||
if Config.train_files:
|
if Config.train_files:
|
||||||
tfv1.reset_default_graph()
|
|
||||||
tfv1.set_random_seed(Config.random_seed)
|
|
||||||
train()
|
train()
|
||||||
|
|
||||||
if Config.test_files:
|
if Config.test_files:
|
||||||
tfv1.reset_default_graph()
|
|
||||||
test()
|
test()
|
||||||
|
|
||||||
if Config.export_dir and not Config.export_zip:
|
if Config.export_dir and not Config.export_zip:
|
||||||
tfv1.reset_default_graph()
|
|
||||||
export()
|
export()
|
||||||
|
|
||||||
if Config.export_zip:
|
if Config.export_zip:
|
||||||
tfv1.reset_default_graph()
|
|
||||||
Config.export_tflite = True
|
Config.export_tflite = True
|
||||||
|
|
||||||
if listdir_remote(Config.export_dir):
|
if listdir_remote(Config.export_dir):
|
||||||
|
@ -1279,7 +1283,6 @@ def main():
|
||||||
package_zip()
|
package_zip()
|
||||||
|
|
||||||
if Config.one_shot_infer:
|
if Config.one_shot_infer:
|
||||||
tfv1.reset_default_graph()
|
|
||||||
do_single_file_inference(Config.one_shot_infer)
|
do_single_file_inference(Config.one_shot_infer)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue