precise-lite-amd64aarch64/precise_lite/scripts/train_incremental.py

161 lines
6.0 KiB
Python

#!/usr/bin/env python3
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from os import makedirs, rename
from os.path import basename, splitext, isfile, join
from prettyparse import Usage
from random import random
from typing import *
from precise_lite.model import create_model, ModelParams
from precise_lite.network_runner import Listener, KerasRunner
from precise_lite.params import pr
from precise_lite.scripts.train import TrainScript
from precise_lite.train_data import TrainData
from precise_lite.util import load_audio, save_audio, glob_all, chunk_audio
def load_trained_fns(model_name: str) -> list:
    """Load the filenames that were already processed for a model.

    The progress file is ``model_name`` with every ``'.net'`` removed, plus
    ``'.trained.txt'``. It holds one filename per line and is decoded as
    UTF-8 with the ``surrogatepass`` error handler, so lone surrogate code
    points in a filename are accepted instead of raising.

    Args:
        model_name: Path of the model whose progress should be restored.

    Returns:
        The saved filenames in file order, or an empty list when no progress
        file exists (or it contains no filenames).
    """
    progress_file = model_name.replace('.net', '') + '.trained.txt'
    if not isfile(progress_file):
        return []

    print('Starting from saved position in', progress_file)
    with open(progress_file, 'rb') as f:
        contents = f.read().decode('utf8', 'surrogatepass')

    # An empty file or a trailing newline would otherwise produce '' entries,
    # which are not real filenames and would be written back on the next save.
    return [fn for fn in contents.split('\n') if fn]
def save_trained_fns(trained_fns: list, model_name: str):
    """Persist the list of already processed filenames for a model.

    The names are written one per line (joined with ``'\\n'``, no trailing
    newline) to ``model_name`` with every ``'.net'`` removed, plus
    ``'.trained.txt'``. The text is encoded as UTF-8 with the
    ``surrogatepass`` error handler.

    Args:
        trained_fns: Filenames to record, in the order they should be stored.
        model_name: Path of the model the progress belongs to.
    """
    progress_file = model_name.replace('.net', '') + '.trained.txt'
    with open(progress_file, 'wb') as out:
        serialized = '\n'.join(trained_fns)
        out.write(serialized.encode('utf8', 'surrogatepass'))
class TrainIncrementalScript(TrainScript):
    """Mine false activations from random audio and retrain the model on them.

    Every wav file found in ``--random-data-folder`` is streamed through the
    current model chunk by chunk. Whenever the network output exceeds
    ``--threshold`` the most recent buffer of audio is saved as a generated
    ``not-wake-word`` sample; once enough samples have been collected the
    model is retrained. Processed files are recorded so that a later run can
    skip them.

    ``self.args``, ``self.sampled_data``, ``self.test``, ``self.epoch`` and
    ``self.callbacks`` are used but not defined here; they are expected to be
    provided by ``TrainScript``.
    """

    # NOTE: the usage text is kept flush-left so that the string handed to
    # prettyparse stays exactly as it was.
    usage = Usage('''
Train a model to inhibit activation by
marking false activations and retraining
:-e --epochs int 1
Number of epochs to train before continuing evaluation
:-ds --delay-samples int 10
Number of false activations to save before re-training
:-c --chunk-size int 2048
Number of samples between testing the neural network
:-r --random-data-folder str data/random
Folder with properly encoded wav files of
random audio that should not cause an activation
:-th --threshold float 0.5
Network output to be considered activated
...
''') | TrainScript.usage

    def __init__(self, args):
        """Prepare output folders, saved progress and the model under training.

        Args:
            args: Parsed command line arguments, forwarded to ``TrainScript``.
        """
        super().__init__(args)

        # Generated false activations are written into these folders, so they
        # have to exist before the first sample is saved.
        generated_dirs = (
            join(self.args.folder, 'not-wake-word', 'generated'),
            join(self.args.folder, 'test', 'not-wake-word', 'generated'),
        )
        for path in generated_dirs:
            makedirs(path, exist_ok=True)

        self.trained_fns = load_trained_fns(self.args.model)
        # Rolling window of the most recent audio; this is what gets saved
        # when the network activates.
        self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)

        params = ModelParams(
            skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics,
            loss_bias=1.0 - self.args.sensitivity
        )
        model = create_model(self.args.model, params)
        self.listener = Listener(self.args.model, self.args.chunk_size, runner_cls=KerasRunner)
        # The listener's runner is replaced and pointed at the model created
        # above, so that retraining and evaluation share one model object.
        self.listener.runner = KerasRunner(self.args.model)
        self.listener.runner.model = model
        # Number of false activations saved since the last retraining
        self.samples_since_train = 0

    @staticmethod
    def load_data(args: Any):
        """Load the tagged data set described by ``args``.

        Args:
            args: Parsed arguments providing ``tags_file``, ``tags_folder``
                and ``no_validation``.

        Returns:
            Whatever ``TrainData.load`` returns for the tagged data; the
            second flag passed to it is ``False`` when ``--no-validation``
            was given.
        """
        data = TrainData.from_tags(args.tags_file, args.tags_folder)
        return data.load(True, not args.no_validation)

    def retrain(self):
        """Train for a session, pulling in any new data from the filesystem

        The data found in ``--folder`` is merged with the data held by the
        base script, the model is fitted for ``--epochs`` further epochs, and
        the model is written back to ``self.args.model`` even when fitting is
        interrupted.
        """
        # Function-scope import: only ``os.replace`` is needed, and only here.
        import os

        folder = TrainData.from_folder(self.args.folder)
        train_data, test_data = folder.load(True, not self.args.no_validation)

        train_data = TrainData.merge(train_data, self.sampled_data)
        test_data = TrainData.merge(test_data, self.test)
        train_inputs, train_outputs = train_data
        print()
        try:
            self.listener.runner.model.fit(
                train_inputs, train_outputs, self.args.batch_size, self.epoch + self.args.epochs,
                validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch
            )
        finally:
            # Save with '.h5' file extension to force format, then move the
            # result onto the original model path. ``os.replace`` overwrites
            # an existing destination on every platform, whereas ``os.rename``
            # fails on Windows once the model file already exists.
            h5_path = self.args.model + '.h5'
            self.listener.runner.model.save(h5_path)
            os.replace(h5_path, self.args.model)

    def train_on_audio(self, fn: str):
        """Run through a single audio file

        The file is fed to the listener chunk by chunk. Each chunk whose
        network output is above ``--threshold`` causes the current audio
        buffer to be saved as a generated not-wake-word sample. Retraining is
        triggered from training-set files once at least ``--delay-samples``
        samples were saved since the last retraining and ``--epochs`` is
        positive.

        Args:
            fn: Path of the audio file to process.
        """
        # Decided once per file: roughly one file in five feeds the test set
        save_test = random() > 0.8
        audio = load_audio(fn)
        # Only used for the progress display; never let it be zero so that a
        # file shorter than one chunk cannot cause a division by zero.
        num_chunks = max(1, len(audio) // self.args.chunk_size)

        # Start every file from a clean slate. The listener is cleared, so the
        # rolling buffer used for saved clips is zeroed as well; otherwise a
        # clip saved early in a file would still contain the tail of the
        # previous file.
        self.listener.clear()
        self.audio_buffer = np.zeros_like(self.audio_buffer)

        for i, chunk in enumerate(chunk_audio(audio, self.args.chunk_size)):
            print('\r' + str(i * 100. / num_chunks) + '%', end='', flush=True)
            # Shift the window left and append the newest chunk
            self.audio_buffer = np.concatenate((self.audio_buffer[len(chunk):], chunk))
            conf = self.listener.update(chunk)
            if conf <= self.args.threshold:
                continue

            # False activation: keep the audio that led up to it
            self.samples_since_train += 1
            name = splitext(basename(fn))[0] + '-' + str(i) + '.wav'
            name = join(self.args.folder, 'test' if save_test else '', 'not-wake-word',
                        'generated', name)
            save_audio(name, self.audio_buffer)
            print()
            print('Saved to:', name)

            enough_samples = self.samples_since_train >= self.args.delay_samples
            if not save_test and enough_samples and self.args.epochs > 0:
                self.samples_since_train = 0
                self.retrain()

    def run(self):
        """
        Begin reading through audio files, saving false
        activations and retraining when necessary
        """
        # Set for constant-time membership tests; ``self.trained_fns`` stays a
        # list so the progress file keeps its order.
        already_trained = set(self.trained_fns)
        for fn in glob_all(self.args.random_data_folder, '*.wav'):
            if fn in already_trained:
                print('Skipping ' + fn + '...')
                continue

            print('Starting file ' + fn + '...')
            self.train_on_audio(fn)
            print('\r100% ')

            # Record progress after every file so an interrupted run can resume
            self.trained_fns.append(fn)
            already_trained.add(fn)
            save_trained_fns(self.trained_fns, self.args.model)
# Module-level entry point. ``run_main`` is not defined in this file; it is
# expected to be inherited by TrainIncrementalScript from its base classes.
main = TrainIncrementalScript.run_main

# Only start the script when this file is executed directly, not on import.
if __name__ == '__main__':
    main()