Allows for generator training.

This commit is contained in:
Andres Elizondo 2020-07-08 12:04:50 -05:00
parent e2a3c2dc58
commit 3baf5da248

View File

@ -19,6 +19,7 @@ import numpy as np
from contextlib import suppress
from fitipy import Fitipy
from keras.callbacks import LambdaCallback
from os import rename
from os.path import splitext, join, basename
from prettyparse import Usage
from random import random, shuffle
@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript):
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor,
save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
self.epoch = epoch_fiti.read().read(0, int)
@ -224,16 +225,16 @@ class TrainGeneratedScript(BaseScript):
def run(self):
"""Train the model on randomly generated batches"""
_, test_data = self.data.load(train=False, test=True)
_, test_data = self.data.load(train=True, test=True)
try:
self.model.fit_generator(
self.model.fit(
self.samples_to_batches(self.generate_samples(), self.args.batch_size),
steps_per_epoch=self.args.steps_per_epoch,
epochs=self.epoch + self.args.epochs, validation_data=test_data,
callbacks=self.callbacks, initial_epoch=self.epoch
)
finally:
self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
rename(self.args.model + '.h5', self.args.model) # Rename with original
save_params(self.args.model)