precise-lite-amd64aarch64/precise_lite/scripts/graph.py

166 lines
5.7 KiB
Python

#!/usr/bin/env python3
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from functools import partial
from os.path import basename, splitext
from prettyparse import Usage
from typing import Callable, Tuple
from precise_lite.network_runner import Listener
from precise_lite.params import inject_params, pr
from precise_lite.scripts.base_script import BaseScript
from precise_lite.stats import Stats
from precise_lite.threshold_decoder import ThresholdDecoder
from precise_lite.train_data import TrainData
def get_thresholds(points=100, power=3) -> list:
"""Run a function with a series of thresholds between 0 and 1"""
return [(i / (points + 1)) ** power for i in range(1, points + 1)]
class CachedDataLoader:
"""
Class for reloading train data every time the params change
Args:
loader: Function that loads the train data (something that calls TrainData.load)
"""
def __init__(self, loader: Callable):
self.prev_cache = None
self.data = None
self.loader = loader
def load_for(self, model: str) -> Tuple[list, list]:
"""Injects the model parameters, reloading if they changed, and returning the data"""
inject_params(model)
if self.prev_cache != pr.vectorization_md5_hash():
self.prev_cache = pr.vectorization_md5_hash()
self.data = self.loader()
return self.data
def load_plt():
try:
import matplotlib.pyplot as plt
return plt
except ImportError:
print('Please install matplotlib first')
raise SystemExit(2)
def calc_stats(model_files, loader, use_train, filenames):
model_data = {}
for model in model_files:
train, test = loader.load_for(model)
inputs, targets = train if use_train else test
print('Running network...')
predictions = Listener.find_runner(model)(model).predict(inputs)
print(inputs.shape, targets.shape)
print('Generating statistics...')
stats = Stats(predictions, targets, filenames)
print('\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n')
model_name = basename(splitext(model)[0])
model_data[model_name] = stats
return model_data
class GraphScript(BaseScript):
usage = Usage('''
Show ROC curves for a series of models
...
:-t --use-train
Evaluate training data instead of test data
:-nf --no-filenames
Don't print out the names of files that failed
:-r --resolution int 100
Number of points to generate
:-p --power float 3.0
Power of point distribution
:-l --labels
Print labels attached to each point
:-o --output-file str -
File to write data instead of displaying it
:-i --input-file str -
File to read data from and visualize
...
''')
usage.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
usage |= TrainData.usage
def __init__(self, args):
super().__init__(args)
if not args.models and not args.input_file and args.folder:
args.input_file = args.folder
if bool(args.models) == bool(args.input_file):
raise ValueError('Please specify either a list of models or an input file')
if not args.output_file:
load_plt() # Error early if matplotlib not installed
def run(self):
args = self.args
if args.models:
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
print('Data:', data)
filenames = sum(data.train_files if args.use_train else data.test_files, [])
loader = CachedDataLoader(partial(
data.load, args.use_train, not args.use_train, shuffle=False
))
model_data = calc_stats(args.models, loader, args.use_train, filenames)
else:
model_data = {
name: Stats.from_np_dict(data) for name, data in np.load(args.input_file)['data'].item().items()
}
for name, stats in model_data.items():
print('=== {} ===\n{}\n\n{}\n'.format(name, stats.counts_str(), stats.summary_str()))
if args.output_file:
np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
else:
plt = load_plt()
decoder = ThresholdDecoder(pr.threshold_config, pr.threshold_center)
thresholds = [decoder.encode(i) for i in np.linspace(0.0, 1.0, args.resolution)[1:-1]]
for model_name, stats in model_data.items():
x = [stats.false_positives(i) for i in thresholds]
y = [stats.false_negatives(i) for i in thresholds]
plt.plot(x, y, marker='x', linestyle='-', label=model_name)
if args.labels:
for x, y, threshold in zip(x, y, thresholds):
plt.annotate('{:.4f}'.format(threshold), (x, y))
plt.legend()
plt.xlabel('False Positives')
plt.ylabel('False Negatives')
plt.show()
main = GraphScript.run_main
if __name__ == '__main__':
main()