Allows for generator training.
This commit is contained in:
parent
e2a3c2dc58
commit
3baf5da248
@ -19,6 +19,7 @@ import numpy as np
|
||||
from contextlib import suppress
|
||||
from fitipy import Fitipy
|
||||
from keras.callbacks import LambdaCallback
|
||||
from os import rename
|
||||
from os.path import splitext, join, basename
|
||||
from prettyparse import Usage
|
||||
from random import random, shuffle
|
||||
@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript):
|
||||
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
|
||||
|
||||
from keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
|
||||
self.epoch = epoch_fiti.read().read(0, int)
|
||||
@ -224,16 +225,16 @@ class TrainGeneratedScript(BaseScript):
|
||||
|
||||
def run(self):
|
||||
"""Train the model on randomly generated batches"""
|
||||
_, test_data = self.data.load(train=False, test=True)
|
||||
_, test_data = self.data.load(train=True, test=True)
|
||||
try:
|
||||
self.model.fit_generator(
|
||||
self.model.fit(
|
||||
self.samples_to_batches(self.generate_samples(), self.args.batch_size),
|
||||
steps_per_epoch=self.args.steps_per_epoch,
|
||||
epochs=self.epoch + self.args.epochs, validation_data=test_data,
|
||||
callbacks=self.callbacks, initial_epoch=self.epoch
|
||||
)
|
||||
finally:
|
||||
self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
|
||||
self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
|
||||
rename(self.args.model + '.h5', self.args.model) # Rename with original
|
||||
save_params(self.args.model)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user