From 3baf5da2489130bd300c6859da13be7f60db8788 Mon Sep 17 00:00:00 2001 From: Andres Elizondo Date: Wed, 8 Jul 2020 12:04:50 -0500 Subject: [PATCH] Allows for generator training. --- precise/scripts/train_generated.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/precise/scripts/train_generated.py b/precise/scripts/train_generated.py index 2f0fa50..e9c5250 100644 --- a/precise/scripts/train_generated.py +++ b/precise/scripts/train_generated.py @@ -19,6 +19,7 @@ import numpy as np from contextlib import suppress from fitipy import Fitipy from keras.callbacks import LambdaCallback +from os import rename from os.path import splitext, join, basename from prettyparse import Usage from random import random, shuffle @@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript): self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None) from keras.callbacks import ModelCheckpoint - checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor, + checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor, save_best_only=args.save_best) epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch') self.epoch = epoch_fiti.read().read(0, int) @@ -224,16 +225,16 @@ class TrainGeneratedScript(BaseScript): def run(self): """Train the model on randomly generated batches""" - _, test_data = self.data.load(train=False, test=True) + _, test_data = self.data.load(train=True, test=True) try: - self.model.fit_generator( + self.model.fit( self.samples_to_batches(self.generate_samples(), self.args.batch_size), steps_per_epoch=self.args.steps_per_epoch, epochs=self.epoch + self.args.epochs, validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch ) finally: - self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format + self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format rename(self.args.model + '.h5', self.args.model) # Rename with original save_params(self.args.model)