diff --git a/precise/scripts/train_generated.py b/precise/scripts/train_generated.py index d5453f9..aa98118 100644 --- a/precise/scripts/train_generated.py +++ b/precise/scripts/train_generated.py @@ -19,6 +19,7 @@ import numpy as np from contextlib import suppress from fitipy import Fitipy from tensorflow.keras.callbacks import LambdaCallback +from os import rename from os.path import splitext, join, basename from prettyparse import Usage from random import random, shuffle @@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript): self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None) from tensorflow.keras.callbacks import ModelCheckpoint - checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor, + checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor, save_best_only=args.save_best) epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch') self.epoch = epoch_fiti.read().read(0, int) @@ -224,16 +225,17 @@ class TrainGeneratedScript(BaseScript): def run(self): """Train the model on randomly generated batches""" - _, test_data = self.data.load(train=False, test=True) + _, test_data = self.data.load(train=True, test=True) try: - self.model.fit_generator( + self.model.fit( self.samples_to_batches(self.generate_samples(), self.args.batch_size), steps_per_epoch=self.args.steps_per_epoch, epochs=self.epoch + self.args.epochs, validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch ) finally: - self.model.save(self.args.model, save_format='h5') + self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format + rename(self.args.model + '.h5', self.args.model) # Rename with original save_params(self.args.model) diff --git a/precise/scripts/train_incremental.py b/precise/scripts/train_incremental.py index ef2082d..67d744e 100644 --- a/precise/scripts/train_incremental.py +++ b/precise/scripts/train_incremental.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np -from os import makedirs +from os import makedirs, rename from os.path import basename, splitext, isfile, join from prettyparse import Usage from random import random @@ -107,7 +107,8 @@ class TrainIncrementalScript(TrainScript): validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch ) finally: - self.listener.runner.model.save(self.args.model,save_format='h5') + self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format + rename(self.args.model + '.h5', self.args.model) # Rename with original def train_on_audio(self, fn: str): """Run through a single audio file"""