Merge pull request #2 from andreselizondo-adestech/incremental_training

Adds Incremental and Generator training
This commit is contained in:
andreselizondo-adestech 2020-08-19 14:31:58 -05:00 committed by GitHub
commit 1979f14b27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 6 deletions

View File

@ -19,6 +19,7 @@ import numpy as np
from contextlib import suppress
from fitipy import Fitipy
from tensorflow.keras.callbacks import LambdaCallback
from os import rename
from os.path import splitext, join, basename
from prettyparse import Usage
from random import random, shuffle
@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript):
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
from tensorflow.keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor,
save_best_only=args.save_best)
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
self.epoch = epoch_fiti.read().read(0, int)
@ -224,16 +225,17 @@ class TrainGeneratedScript(BaseScript):
def run(self):
"""Train the model on randomly generated batches"""
_, test_data = self.data.load(train=False, test=True)
_, test_data = self.data.load(train=True, test=True)
try:
self.model.fit_generator(
self.model.fit(
self.samples_to_batches(self.generate_samples(), self.args.batch_size),
steps_per_epoch=self.args.steps_per_epoch,
epochs=self.epoch + self.args.epochs, validation_data=test_data,
callbacks=self.callbacks, initial_epoch=self.epoch
)
finally:
self.model.save(self.args.model, save_format='h5')
self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
rename(self.args.model + '.h5', self.args.model) # Rename with original
save_params(self.args.model)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from os import makedirs
from os import makedirs, rename
from os.path import basename, splitext, isfile, join
from prettyparse import Usage
from random import random
@ -107,7 +107,8 @@ class TrainIncrementalScript(TrainScript):
validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch
)
finally:
self.listener.runner.model.save(self.args.model,save_format='h5')
self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
rename(self.args.model + '.h5', self.args.model) # Rename with original
def train_on_audio(self, fn: str):
"""Run through a single audio file"""