Merge pull request #2 from andreselizondo-adestech/incremental_training
Adds Incremental and Generator training
This commit is contained in:
commit
1979f14b27
@ -19,6 +19,7 @@ import numpy as np
|
||||
from contextlib import suppress
|
||||
from fitipy import Fitipy
|
||||
from tensorflow.keras.callbacks import LambdaCallback
|
||||
from os import rename
|
||||
from os.path import splitext, join, basename
|
||||
from prettyparse import Usage
|
||||
from random import random, shuffle
|
||||
@ -91,7 +92,7 @@ class TrainGeneratedScript(BaseScript):
|
||||
self.listener = Listener('', args.chunk_size, runner_cls=lambda x: None)
|
||||
|
||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||
checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
checkpoint = ModelCheckpoint(args.model + '.pb', monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
|
||||
self.epoch = epoch_fiti.read().read(0, int)
|
||||
@ -224,16 +225,17 @@ class TrainGeneratedScript(BaseScript):
|
||||
|
||||
def run(self):
|
||||
"""Train the model on randomly generated batches"""
|
||||
_, test_data = self.data.load(train=False, test=True)
|
||||
_, test_data = self.data.load(train=True, test=True)
|
||||
try:
|
||||
self.model.fit_generator(
|
||||
self.model.fit(
|
||||
self.samples_to_batches(self.generate_samples(), self.args.batch_size),
|
||||
steps_per_epoch=self.args.steps_per_epoch,
|
||||
epochs=self.epoch + self.args.epochs, validation_data=test_data,
|
||||
callbacks=self.callbacks, initial_epoch=self.epoch
|
||||
)
|
||||
finally:
|
||||
self.model.save(self.args.model, save_format='h5')
|
||||
self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
|
||||
rename(self.args.model + '.h5', self.args.model) # Rename with original
|
||||
save_params(self.args.model)
|
||||
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from os import makedirs
|
||||
from os import makedirs, rename
|
||||
from os.path import basename, splitext, isfile, join
|
||||
from prettyparse import Usage
|
||||
from random import random
|
||||
@ -107,7 +107,8 @@ class TrainIncrementalScript(TrainScript):
|
||||
validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch
|
||||
)
|
||||
finally:
|
||||
self.listener.runner.model.save(self.args.model,save_format='h5')
|
||||
self.listener.runner.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
|
||||
rename(self.args.model + '.h5', self.args.model) # Rename with original
|
||||
|
||||
def train_on_audio(self, fn: str):
|
||||
"""Run through a single audio file"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user