Allows for generator training.

This commit is contained in:
Andres Elizondo 2020-07-08 12:04:50 -05:00
parent e2a3c2dc58
commit 3baf5da248

View File

@ -19,6 +19,7 @@ import numpy as np
from contextlib import suppress from contextlib import suppress
from fitipy import Fitipy from fitipy import Fitipy
from keras.callbacks import LambdaCallback from keras.callbacks import LambdaCallback
from os import rename
from os.path import splitext, join, basename from os.path import splitext, join, basename
from prettyparse import Usage from prettyparse import Usage
from random import random, shuffle from random import random, shuffle
@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript):
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None) self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
from keras.callbacks import ModelCheckpoint from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor, checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor,
save_best_only=args.save_best) save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch') epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
self.epoch = epoch_fiti.read().read(0, int) self.epoch = epoch_fiti.read().read(0, int)
@ -224,16 +225,16 @@ class TrainGeneratedScript(BaseScript):
def run(self): def run(self):
"""Train the model on randomly generated batches""" """Train the model on randomly generated batches"""
_, test_data = self.data.load(train=False, test=True) _, test_data = self.data.load(train=True, test=True)
try: try:
self.model.fit_generator( self.model.fit(
self.samples_to_batches(self.generate_samples(), self.args.batch_size), self.samples_to_batches(self.generate_samples(), self.args.batch_size),
steps_per_epoch=self.args.steps_per_epoch, steps_per_epoch=self.args.steps_per_epoch,
epochs=self.epoch + self.args.epochs, validation_data=test_data, epochs=self.epoch + self.args.epochs, validation_data=test_data,
callbacks=self.callbacks, initial_epoch=self.epoch callbacks=self.callbacks, initial_epoch=self.epoch
) )
finally: finally:
self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
rename(self.args.model + '.h5', self.args.model) # Rename with original rename(self.args.model + '.h5', self.args.model) # Rename with original
save_params(self.args.model) save_params(self.args.model)