precise-lite-amd64aarch64/precise_lite/scripts/calc_threshold.py

91 lines
2.9 KiB
Python

#!/usr/bin/env python3
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import sqrt
from os.path import basename, splitext
from prettyparse import Usage
from precise_lite.params import inject_params, save_params
from precise_lite.scripts.base_script import BaseScript
from precise_lite.stats import Stats
class CalcThresholdScript(BaseScript):
    """Fit a model's threshold-decoding parameters to a recorded stats file.

    The script looks up one model's stats in ``input_file``, applies the
    inverse sigmoid (logit) to the recorded network outputs, and measures the
    mean and smoothed standard deviation of that value over the samples whose
    target is above 0.5. The resulting ``(mean, std)`` pair is stored as
    ``threshold_config`` on the model's params and saved next to the model.
    """

    # The usage text below is parsed at runtime by prettyparse, so its content
    # is kept exactly as-is (including the lack of indentation).
    # NOTE(review): ``--center`` is declared here but is not read by run().
    usage = Usage('''
Update the threshold values of a model for a dataset.
This makes the sensitivity more accurate and linear
:model str
Either Keras (.net) or TensorFlow (.pb) model to adjust
:input_file str
Input stats file that was outputted from precise_lite-graph
:-k --model-key str -
Custom model name to use from the stats.json
:-s --smoothing float 1.2
Amount of extra smoothing to apply
:-c --center float 0.2
Decoded threshold that is mapped to 0.5. Proportion of
false negatives at sensitivity=0.5
''')

    def __init__(self, args):
        """Store the parsed command-line arguments via the base script."""
        super().__init__(args)

    def run(self):
        """Compute the threshold config for ``args.model`` and save it.

        Prints a message and returns without saving anything when the stats
        contain no usable output, or no positive sample to fit against.

        Raises:
            SystemExit: With status 1 when the requested model name is not
                present in the stats file.
        """
        args = self.args
        import numpy as np

        # The 'data' entry is a pickled dict wrapped in a 0-d object array
        # (hence the .item() call). NumPy >= 1.16.3 refuses to load object
        # arrays unless allow_pickle=True is passed explicitly.
        # NOTE(security): unpickling can execute arbitrary code, so only use
        # stats files that come from a trusted source.
        with np.load(args.input_file, allow_pickle=True) as archive:
            raw_stats = archive['data'].item()
        model_data = {
            name: Stats.from_np_dict(data) for name, data in raw_stats.items()
        }

        # Default to the model's file name without directory or extension.
        model_name = args.model_key or basename(splitext(args.model)[0])
        if model_name not in model_data:
            print("Could not find model '{}' in saved models in stats file: {}".format(model_name, list(model_data)))
            raise SystemExit(1)
        stats = model_data[model_name]

        # Outputs of exactly 0 or 1 are dropped because the inverse sigmoid
        # computed below is infinite at those values.
        save_spots = (stats.outputs != 0) & (stats.outputs != 1)
        if not save_spots.any():
            print('No data (or all NaN)')
            return
        stats.outputs = stats.outputs[save_spots]
        stats.targets = stats.targets[save_spots]

        # Inverse sigmoid: maps each output in (0, 1) back to a raw logit.
        inv = -np.log(1 / stats.outputs - 1)
        pos = np.extract(stats.targets > 0.5, inv)
        if pos.size == 0:
            # Without this guard the mean of an empty array is NaN and a NaN
            # threshold config would be written into the model's params.
            print('No positive samples with usable outputs; params left unchanged')
            return

        pos_mu = pos.mean().item()
        # Population standard deviation, widened by the smoothing factor.
        pos_std = sqrt(np.mean((pos - pos_mu) ** 2)) * args.smoothing
        print('Peak: {:.2f} mu, {:.2f} std'.format(pos_mu, pos_std))

        pr = inject_params(args.model)
        pr.__dict__.update(threshold_config=(
            (pos_mu, pos_std),
        ))
        save_params(args.model)
        print('Saved params to {}.params'.format(args.model))
# Module-level entry point: an alias of the script class's ``run_main``.
# ``run_main`` is not defined in CalcThresholdScript itself; presumably it is
# inherited from BaseScript -- confirm against base_script.
main = CalcThresholdScript.run_main
# Invoke the entry point when this file is executed directly as a script.
if __name__ == '__main__':
    main()