166 lines
5.7 KiB
Python
166 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2019 Mycroft AI Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License
|
|
import numpy as np
|
|
from functools import partial
|
|
from os.path import basename, splitext
|
|
from prettyparse import Usage
|
|
from typing import Callable, Tuple
|
|
|
|
from precise_lite.network_runner import Listener
|
|
from precise_lite.params import inject_params, pr
|
|
from precise_lite.scripts.base_script import BaseScript
|
|
from precise_lite.stats import Stats
|
|
from precise_lite.threshold_decoder import ThresholdDecoder
|
|
from precise_lite.train_data import TrainData
|
|
|
|
|
|
def get_thresholds(points=100, power=3) -> list:
    """
    Build a list of threshold values spaced along a power curve

    Args:
        points: Number of thresholds to generate
        power: Exponent applied to each evenly spaced fraction

    Returns:
        list holding (i / (points + 1)) ** power for i = 1..points. With a
        positive power every value lies strictly between 0 and 1
    """
    divisor = points + 1
    thresholds = []
    for step in range(1, divisor):
        thresholds.append((step / divisor) ** power)
    return thresholds
|
|
|
|
|
|
class CachedDataLoader:
    """
    Keeps the most recently loaded train data and only reloads it when the
    vectorization hash of the injected params differs from the previous one

    Args:
        loader: Zero-argument callable that produces the data (something that calls TrainData.load)
    """

    def __init__(self, loader: Callable):
        self.loader = loader
        # Nothing is loaded until the first load_for() call
        self.data = None
        # Vectorization hash the cached data was loaded under
        self.prev_cache = None

    def load_for(self, model: str) -> Tuple[list, list]:
        """Inject the params of the given model and return the data, reloading it first if the hash changed"""
        inject_params(model)
        if self.prev_cache == pr.vectorization_md5_hash():
            # Same vectorization settings as last time: reuse what we have
            return self.data

        # The hash is recorded before the loader runs
        self.prev_cache = pr.vectorization_md5_hash()
        self.data = self.loader()
        return self.data
|
|
|
|
|
|
def load_plt():
    """Return the matplotlib.pyplot module, or print a hint and exit with status 2 if it cannot be imported"""
    try:
        import matplotlib.pyplot as pyplot
    except ImportError:
        print('Please install matplotlib first')
        raise SystemExit(2)
    return pyplot
|
|
|
|
|
|
def calc_stats(model_files, loader, use_train, filenames):
    """
    Run each model over the loaded data and gather its statistics

    Args:
        model_files: Iterable of model paths to evaluate
        loader: Object whose load_for(model) returns a (train, test) pair, each of
            which unpacks into (inputs, targets)
        use_train: Evaluate the train pair when truthy, otherwise the test pair
        filenames: Handed to Stats together with the predictions and targets

    Returns:
        dict mapping each model's base filename (extension removed) to its Stats
    """
    stats_by_model = {}
    for model_path in model_files:
        train_set, test_set = loader.load_for(model_path)
        if use_train:
            inputs, targets = train_set
        else:
            inputs, targets = test_set

        print('Running network...')
        runner_cls = Listener.find_runner(model_path)
        predictions = runner_cls(model_path).predict(inputs)
        print(inputs.shape, targets.shape)

        print('Generating statistics...')
        stats = Stats(predictions, targets, filenames)
        report = '\n' + stats.counts_str() + '\n\n' + stats.summary_str() + '\n'
        print(report)

        # NOTE: keyed by bare file name, so two models sharing a base name in
        # different folders end up in the same slot (the later one wins)
        path_without_ext = splitext(model_path)[0]
        stats_by_model[basename(path_without_ext)] = stats
    return stats_by_model
|
|
|
|
|
|
class GraphScript(BaseScript):
    """
    Script that plots false positives against false negatives for a series of models

    The statistics come either from running the given models over the dataset or
    from a file previously written through --output-file. They are then either
    saved with np.savez or shown with matplotlib
    """
    usage = Usage('''
        Show ROC curves for a series of models

        ...

        :-t --use-train
            Evaluate training data instead of test data

        :-nf --no-filenames
            Don't print out the names of files that failed

        :-r --resolution int 100
            Number of points to generate

        :-p --power float 3.0
            Power of point distribution

        :-l --labels
            Print labels attached to each point

        :-o --output-file str -
            File to write data instead of displaying it

        :-i --input-file str -
            File to read data from and visualize

        ...
    ''')
    usage.add_argument('models', nargs='*', help='Either Keras (.net) or TensorFlow (.pb) models to test')
    usage |= TrainData.usage

    def __init__(self, args):
        """
        Validate the argument combination before any work is done

        Raises:
            ValueError: If both or neither of models / input file are given
            SystemExit: Via load_plt() when plotting is requested without matplotlib
        """
        super().__init__(args)
        # With no models and no explicit input file, the folder argument is taken
        # as the input file (presumably so a lone positional path works - TODO confirm)
        if not args.models and not args.input_file and args.folder:
            args.input_file = args.folder
        # Exactly one of the two sources must be provided
        if bool(args.models) == bool(args.input_file):
            raise ValueError('Please specify either a list of models or an input file')

        if not args.output_file:
            load_plt()  # Error early if matplotlib not installed

    def run(self):
        """Compute or load the per-model stats, then write them to a file or plot them"""
        # NOTE(review): args.power and args.no_filenames are parsed by the usage
        # above but are not read anywhere in this class
        args = self.args
        if args.models:
            data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
            print('Data:', data)
            filenames = sum(data.train_files if args.use_train else data.test_files, [])
            loader = CachedDataLoader(partial(
                data.load, args.use_train, not args.use_train, shuffle=False
            ))
            model_data = calc_stats(args.models, loader, args.use_train, filenames)
        else:
            # The file is written below by np.savez with a dict as the 'data'
            # entry, which NumPy stores as a pickled object array. np.load refuses
            # those unless allow_pickle=True (default False since NumPy 1.16.3).
            # SECURITY: unpickling can execute arbitrary code, so only open
            # files that come from a trusted source.
            with np.load(args.input_file, allow_pickle=True) as archive:
                saved_stats = archive['data'].item()
            model_data = {
                name: Stats.from_np_dict(data) for name, data in saved_stats.items()
            }
            for name, stats in model_data.items():
                print('=== {} ===\n{}\n\n{}\n'.format(name, stats.counts_str(), stats.summary_str()))

        if args.output_file:
            np.savez(args.output_file, data={name: stats.to_np_dict() for name, stats in model_data.items()})
        else:
            plt = load_plt()
            decoder = ThresholdDecoder(pr.threshold_config, pr.threshold_center)
            # [1:-1] drops the 0.0 and 1.0 endpoints, leaving resolution - 2 thresholds
            thresholds = [decoder.encode(i) for i in np.linspace(0.0, 1.0, args.resolution)[1:-1]]
            for model_name, stats in model_data.items():
                x = [stats.false_positives(i) for i in thresholds]
                y = [stats.false_negatives(i) for i in thresholds]
                plt.plot(x, y, marker='x', linestyle='-', label=model_name)
                if args.labels:
                    # Distinct loop names so the plotted x / y lists are not rebound
                    for point_x, point_y, threshold in zip(x, y, thresholds):
                        plt.annotate('{:.4f}'.format(threshold), (point_x, point_y))

            plt.legend()
            plt.xlabel('False Positives')
            plt.ylabel('False Negatives')
            plt.show()
|
|
|
|
|
|
# Module-level entry point: an alias for GraphScript.run_main (run_main is not
# defined in GraphScript itself, so presumably it is inherited from BaseScript)
main = GraphScript.run_main

# Run the script when this file is executed directly rather than imported
if __name__ == '__main__':
    main()
|