172 lines
5.9 KiB
Python
172 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2019 Mycroft AI Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from fitipy import Fitipy
|
|
from tensorflow.keras.callbacks import LambdaCallback
|
|
from os.path import splitext, isfile
|
|
from prettyparse import Usage
|
|
from typing import Any, Tuple
|
|
|
|
from precise_lite.model import create_model, ModelParams
|
|
from precise_lite.params import inject_params, save_params
|
|
from precise_lite.scripts.base_script import BaseScript
|
|
from precise_lite.train_data import TrainData
|
|
from precise_lite.util import calc_sample_hash
|
|
|
|
|
|
class TrainScript(BaseScript):
    """Train a wake-word model on a dataset of audio samples.

    Loads (or creates) a Keras model, optionally restricts training to a
    subset of samples listed in a json file, and resumes epoch counting
    across invocations via a ``.epoch`` sidecar file next to the model.
    """

    usage = Usage('''
        Train a new model on a dataset

        :model str
            Keras model file (.net) to load from and save to

        :-sf --samples-file str -
            Loads subset of data from the provided json file
            generated with precise_lite-train-sampled

        :-is --invert-samples
            Loads subset of data not inside --samples-file

        :-e --epochs int 10
            Number of epochs to train model for

        :-s --sensitivity float 0.2
            Weighted loss bias. Higher values decrease false negatives
            at the cost of more false positives

        :-b --batch-size int 5000
            Batch size for training

        :-sb --save-best
            Only save the model each epoch if its stats improve

        :-nv --no-validation
            Disable accuracy and validation calculation
            to improve speed during training

        :-mm --metric-monitor str loss
            Metric used to determine when to save

        :-em --extra-metrics
            Add extra metrics during training

        :-f --freeze-till int 0
            Freeze all weights up to this index (non-inclusive).
            Can be negative to wrap from end

        ...
    ''') | TrainData.usage

    def __init__(self, args):
        super().__init__(args)

        # --- argument validation -------------------------------------------
        if args.invert_samples and not args.samples_file:
            raise ValueError('You must specify --samples-file when using --invert-samples')
        if args.samples_file and not isfile(args.samples_file):
            # Bug fix: the original formatted this message with
            # args.invert_samples (a bool), yielding "No such file: True";
            # the missing path is always args.samples_file.
            raise ValueError('No such file: ' + args.samples_file)
        if not 0.0 <= args.sensitivity <= 1.0:
            raise ValueError('sensitivity must be between 0.0 and 1.0')

        # --- model and data setup ------------------------------------------
        inject_params(args.model)
        save_params(args.model)
        # loss_bias = 1 - sensitivity: higher sensitivity weights the
        # positive class more heavily in the loss
        params = ModelParams(skip_acc=args.no_validation, extra_metrics=args.extra_metrics,
                             loss_bias=1.0 - args.sensitivity, freeze_till=args.freeze_till)
        self.model = create_model(args.model, params)
        self.train, self.test = self.load_data(self.args)

        from tensorflow.keras.callbacks import ModelCheckpoint
        checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
                                     save_best_only=args.save_best)

        # Persist the epoch counter beside the model so repeated runs
        # resume numbering instead of restarting at 0
        epoch_fiti = Fitipy(splitext(args.model)[0] + '.epoch')
        self.epoch = epoch_fiti.read().read(0, int)

        def on_epoch_end(_a, _b):
            # Keras passes (epoch, logs); both are unused here
            self.epoch += 1
            epoch_fiti.write().write(self.epoch, str)

        self.model_base = splitext(self.args.model)[0]

        if args.samples_file:
            self.samples, self.hash_to_ind = self.load_sample_data(args.samples_file, self.train)
        else:
            self.samples = set()
            self.hash_to_ind = {}

        self.callbacks = [
            checkpoint,
            LambdaCallback(on_epoch_end=on_epoch_end),
        ]

    @staticmethod
    def load_sample_data(filename, train_data) -> Tuple[set, dict]:
        """Load the sample-hash subset and map each hash to its index.

        Args:
            filename: json file of sample hashes (from precise_lite-train-sampled)
            train_data: (inputs, outputs) pair whose rows the hashes refer to

        Returns:
            (samples, hash_to_ind): the loaded hash set (pruned of hashes
            not present in train_data) and a hash -> row-index mapping
        """
        samples = Fitipy(filename).read().set()
        hash_to_ind = {
            calc_sample_hash(inp, outp): ind
            for ind, (inp, outp) in enumerate(zip(*train_data))
        }
        # Drop hashes that no longer match any loaded sample (iterate a
        # copy since we mutate the set)
        for hsh in list(samples):
            if hsh not in hash_to_ind:
                print('Missing hash:', hsh)
                samples.remove(hsh)
        return samples, hash_to_ind

    @staticmethod
    def load_data(args: Any) -> Tuple[tuple, tuple]:
        """Load train/test data from the tags file/folder given in args.

        Exits the process if the training set is empty.
        """
        data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
        print('Data:', data)
        train, test = data.load(True, not args.no_validation)

        print('Inputs shape:', train[0].shape)
        print('Outputs shape:', train[1].shape)

        if test:
            print('Test inputs shape:', test[0].shape)
            print('Test outputs shape:', test[1].shape)

        if 0 in train[0].shape or 0 in train[1].shape:
            print('Not enough data to train')
            # raise SystemExit rather than the site-injected exit() builtin,
            # which is unavailable under `python -S`
            raise SystemExit(1)

        return train, test

    @property
    def sampled_data(self):
        """Returns (train_inputs, train_outputs)

        When --samples-file is given, restricts training rows to the loaded
        hash subset (or its complement with --invert-samples); otherwise
        returns the full training set.
        """
        if self.args.samples_file:
            if self.args.invert_samples:
                chosen_samples = set(self.hash_to_ind) - self.samples
            else:
                chosen_samples = self.samples
            selected_indices = [self.hash_to_ind[h] for h in chosen_samples]
            return self.train[0][selected_indices], self.train[1][selected_indices]
        else:
            return self.train[0], self.train[1]

    def run(self):
        """Run training, resuming from the persisted epoch counter."""
        self.model.summary()
        train_inputs, train_outputs = self.sampled_data
        self.model.fit(
            train_inputs, train_outputs, self.args.batch_size,
            self.epoch + self.args.epochs, validation_data=self.test,
            initial_epoch=self.epoch, callbacks=self.callbacks,
            use_multiprocessing=True, validation_freq=5,
            verbose=1
        )
|
|
|
|
|
|
# Entry point for the precise_lite-train console script (run_main is
# provided by BaseScript and handles argument parsing)
main = TrainScript.run_main

if __name__ == '__main__':
    main()
|