From 256af35a614a41991a1c82279b2a20fe776dcaa8 Mon Sep 17 00:00:00 2001 From: Josh Meyer Date: Fri, 30 Jul 2021 13:25:36 -0400 Subject: [PATCH] Move tfv1 calls inside high-level functions --- notebooks/easy-transfer-learning.ipynb | 3 --- notebooks/train-your-first-coqui-STT-model.ipynb | 3 --- training/coqui_stt_training/train.py | 15 +++++++++------ 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/notebooks/easy-transfer-learning.ipynb b/notebooks/easy-transfer-learning.ipynb index 2f5f8d7b..b83f1f80 100644 --- a/notebooks/easy-transfer-learning.ipynb +++ b/notebooks/easy-transfer-learning.ipynb @@ -196,14 +196,12 @@ "outputs": [], "source": [ "from coqui_stt_training.train import train, early_training_checks\n", - "import tensorflow.compat.v1 as tfv1\n", "\n", "# use maximum one GPU\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "early_training_checks()\n", "\n", - "tfv1.reset_default_graph()\n", "train()" ] }, @@ -257,7 +255,6 @@ "source": [ "from coqui_stt_training.train import test\n", "\n", - "tfv1.reset_default_graph()\n", "test()" ] } diff --git a/notebooks/train-your-first-coqui-STT-model.ipynb b/notebooks/train-your-first-coqui-STT-model.ipynb index 5eccccc8..2009dfa0 100644 --- a/notebooks/train-your-first-coqui-STT-model.ipynb +++ b/notebooks/train-your-first-coqui-STT-model.ipynb @@ -198,14 +198,12 @@ "outputs": [], "source": [ "from coqui_stt_training.train import train, early_training_checks\n", - "import tensorflow.compat.v1 as tfv1\n", "\n", "# use maximum one GPU\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "early_training_checks()\n", "\n", - "tfv1.reset_default_graph()\n", "train()" ] }, @@ -236,7 +234,6 @@ "source": [ "from coqui_stt_training.train import test\n", "\n", - "tfv1.reset_default_graph()\n", "test()" ] } diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index b6cfaec9..42b1767c 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -522,6 +522,9 @@ def log_grads_and_vars(grads_and_vars): def train(): + tfv1.reset_default_graph() + tfv1.set_random_seed(Config.random_seed) + exception_box = ExceptionBox() # Create training and validation datasets @@ -896,6 +899,8 @@ def train(): def test(): + tfv1.reset_default_graph() + samples = evaluate(Config.test_files, create_model) if Config.test_output_file: save_samples_json(samples, Config.test_output_file) @@ -1027,6 +1032,8 @@ def export(): """ log_info("Exporting the model...") + tfv1.reset_default_graph() + inputs, outputs, _ = create_inference_graph( batch_size=Config.export_batch_size, n_steps=Config.n_steps, @@ -1176,6 +1183,8 @@ def package_zip(): def do_single_file_inference(input_file_path): + tfv1.reset_default_graph() + with tfv1.Session(config=Config.session_config) as session: inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) @@ -1253,20 +1262,15 @@ def main(): early_training_checks() if Config.train_files: - tfv1.reset_default_graph() - tfv1.set_random_seed(Config.random_seed) train() if Config.test_files: - tfv1.reset_default_graph() test() if Config.export_dir and not Config.export_zip: - tfv1.reset_default_graph() export() if Config.export_zip: - tfv1.reset_default_graph() Config.export_tflite = True if listdir_remote(Config.export_dir): @@ -1279,7 +1283,6 @@ def main(): package_zip() if Config.one_shot_infer: - tfv1.reset_default_graph() do_single_file_inference(Config.one_shot_infer)