tflite edition #0
This commit is contained in:
parent
279bbc59d3
commit
2dea7953e8
@ -1,4 +1,4 @@
|
||||
# Mycroft Precise
|
||||
# Mycroft Precise-Lite
|
||||
|
||||
*A lightweight, simple-to-use, RNN wake word listener.*
|
||||
|
||||
|
61
build.sh
61
build.sh
@ -1,61 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
|
||||
tar_name() {
|
||||
local tar_prefix=$1
|
||||
echo "${tar_prefix}_$(precise-engine --version 2>&1)_$(uname -m).tar.gz"
|
||||
}
|
||||
|
||||
replace() {
|
||||
local pattern=$1
|
||||
local replacement=$2
|
||||
sed -e "s/$pattern/$replacement/gm"
|
||||
}
|
||||
|
||||
package_scripts() {
|
||||
local tar_prefix=$1
|
||||
local combined_folder=$2
|
||||
local scripts=$3
|
||||
local train_libs=$4
|
||||
local completed_file="dist/completed_$combined_folder.txt"
|
||||
|
||||
if ! [ -f "$completed_file" ]; then
|
||||
rm -rf "dist/$combined_folder"
|
||||
fi
|
||||
mkdir -p "dist/$combined_folder"
|
||||
|
||||
for script in $scripts; do
|
||||
exe=precise-$(echo "$script" | tr '_' '-')
|
||||
if [ -f "$completed_file" ] && grep -qF "$exe" "$completed_file"; then
|
||||
continue
|
||||
fi
|
||||
tmp_name=$(mktemp).spec
|
||||
cat "precise.template.spec" | replace "%%SCRIPT%%" "$script" | replace "%%TRAIN_LIBS%%" "$train_libs" > "$tmp_name"
|
||||
pyinstaller -y "$tmp_name"
|
||||
if [ "$exe" != "$combined_folder" ]; then
|
||||
cp -R dist/$exe/* "dist/$combined_folder"
|
||||
rm -rf "dist/$exe" "build/$exe"
|
||||
fi
|
||||
echo "$exe" >> "$completed_file"
|
||||
done
|
||||
|
||||
out_name=$(tar_name "$tar_prefix")
|
||||
cd dist
|
||||
tar czvf "$out_name" "$combined_folder"
|
||||
md5sum "$out_name" > "$out_name.md5"
|
||||
cd ..
|
||||
}
|
||||
|
||||
set -eE
|
||||
|
||||
./setup.sh
|
||||
source .venv/bin/activate
|
||||
pip install pyinstaller
|
||||
|
||||
all_scripts=$(grep -oP '(?<=precise.scripts.)[a-z_]+' setup.py)
|
||||
package_scripts "precise-all" "precise" "$all_scripts" True
|
||||
package_scripts "precise-engine" "precise-engine" "engine" False
|
||||
|
||||
tar_1=dist/$(tar_name precise-all)
|
||||
tar_2=dist/$(tar_name precise-engine)
|
||||
echo "Wrote to $tar_1 and $tar_2"
|
60
export.sh
60
export.sh
@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
if ! [[ "$1" =~ .*\.net$ ]] || ! [ -f "$1" ] || ! [[ $# =~ [2-3] ]]; then
|
||||
echo "Usage: $0 <model>.net GITHUB_REPO [BRANCH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
model_file=$(readlink -f "$1")
|
||||
repo=$2
|
||||
branch=${3-master}
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
set -e
|
||||
cache=.cache/precise-data/${repo//\//.}.${branch//\//.}
|
||||
[ -d "$cache" ] || git clone "$repo" "$cache" -b "$branch" --single-branch
|
||||
|
||||
pushd "$cache"
|
||||
git fetch
|
||||
git checkout "$branch"
|
||||
git reset --hard "origin/$branch"
|
||||
popd
|
||||
|
||||
source .venv/bin/activate
|
||||
model_name=$(basename "${1%%.net}")
|
||||
precise-convert "$model_file" -o "$cache/$model_name.pb"
|
||||
|
||||
pushd "$cache"
|
||||
tar cvf "$model_name.tar.gz" "$model_name.pb" "$model_name.pb.params"
|
||||
md5sum "$model_name.tar.gz" > "$model_name.tar.gz.md5"
|
||||
rm -f "$model_name.pb" "$model_name.pb.params" "$model_name.pbtxt"
|
||||
git reset
|
||||
git add "$model_name.tar.gz" "$model_name.tar.gz.md5"
|
||||
|
||||
echo
|
||||
ls
|
||||
git status
|
||||
|
||||
read -p "Uploading $model_name model to branch $branch on repo $repo. Confirm? (y/N) " answer
|
||||
if [ "$answer" != "y" ] && [ "$answer" != "Y" ]; then
|
||||
echo "Aborted."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
git commit -m "Update $model_name"
|
||||
git push
|
||||
popd
|
||||
|
@ -1,75 +0,0 @@
|
||||
# -*- mode: python -*-
|
||||
block_cipher = None
|
||||
|
||||
from glob import iglob
|
||||
from os.path import basename, dirname, abspath
|
||||
import os
|
||||
import fnmatch
|
||||
|
||||
script_name = '%%SCRIPT%%'
|
||||
train_libs = %%TRAIN_LIBS%%
|
||||
strip = True
|
||||
site_packages = '.venv/lib/python3.6/site-packages/'
|
||||
hidden_imports = ['prettyparse', 'speechpy']
|
||||
binaries = []
|
||||
|
||||
|
||||
def recursive_glob(treeroot, pattern):
|
||||
results = []
|
||||
for base, dirs, files in os.walk(treeroot):
|
||||
goodfiles = fnmatch.filter(files, pattern)
|
||||
results.extend(os.path.join(base, f) for f in goodfiles)
|
||||
return results
|
||||
|
||||
|
||||
if train_libs:
|
||||
binaries = [
|
||||
(abspath(i), dirname(i.replace(site_packages, '')))
|
||||
for i in recursive_glob(site_packages + "tensorflow/", "*.so")
|
||||
]
|
||||
hidden_imports += ['h5py']
|
||||
|
||||
a = Analysis(
|
||||
[abspath('precise/scripts/{}.py'.format(script_name))],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=[],
|
||||
hiddenimports=hidden_imports,
|
||||
hookspath=[],
|
||||
runtime_hooks=[],
|
||||
excludes=['PySide', 'PyQt4', 'PyQt5', 'matplotlib'],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher
|
||||
)
|
||||
|
||||
for i in range(len(a.binaries)):
|
||||
dest, origin, kind = a.binaries[i]
|
||||
if '_pywrap_tensorflow_internal' in dest:
|
||||
a.binaries[i] = ('tensorflow.python.' + dest, origin, kind)
|
||||
|
||||
pyz = PYZ(
|
||||
a.pure, a.zipped_data,
|
||||
cipher=block_cipher
|
||||
)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
exclude_binaries=True,
|
||||
name='precise-{}'.format(script_name.replace('_', '-')),
|
||||
debug=False,
|
||||
strip=strip,
|
||||
upx=True,
|
||||
console=True
|
||||
)
|
||||
|
||||
coll = COLLECT(
|
||||
exe,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
strip=strip,
|
||||
upx=True,
|
||||
name='precise-{}'.format(script_name.replace('_', '-'))
|
||||
)
|
@ -1 +0,0 @@
|
||||
__version__ = '0.3.0'
|
@ -1,69 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from typing import *
|
||||
from typing import BinaryIO
|
||||
|
||||
from precise.params import pr
|
||||
from precise.util import audio_to_buffer
|
||||
|
||||
|
||||
class PocketsphinxListener:
|
||||
"""Pocketsphinx listener implementation used for comparison with Precise"""
|
||||
|
||||
def __init__(self, key_phrase, dict_file, hmm_folder, threshold=1e-90, chunk_size=-1):
|
||||
from pocketsphinx import Decoder
|
||||
config = Decoder.default_config()
|
||||
config.set_string('-hmm', hmm_folder)
|
||||
config.set_string('-dict', dict_file)
|
||||
config.set_string('-keyphrase', key_phrase)
|
||||
config.set_float('-kws_threshold', float(threshold))
|
||||
config.set_float('-samprate', 16000)
|
||||
config.set_int('-nfft', 2048)
|
||||
config.set_string('-logfn', '/dev/null')
|
||||
self.key_phrase = key_phrase
|
||||
self.buffer = b'\0' * pr.sample_depth * pr.buffer_samples
|
||||
self.pr = pr
|
||||
self.read_size = -1 if chunk_size == -1 else pr.sample_depth * chunk_size
|
||||
|
||||
try:
|
||||
self.decoder = Decoder(config)
|
||||
except RuntimeError:
|
||||
options = dict(key_phrase=key_phrase, dict_file=dict_file,
|
||||
hmm_folder=hmm_folder, threshold=threshold)
|
||||
raise RuntimeError('Invalid Pocketsphinx options: ' + str(options))
|
||||
|
||||
def _transcribe(self, byte_data):
|
||||
self.decoder.start_utt()
|
||||
self.decoder.process_raw(byte_data, False, False)
|
||||
self.decoder.end_utt()
|
||||
return self.decoder.hyp()
|
||||
|
||||
def found_wake_word(self, frame_data):
|
||||
hyp = self._transcribe(frame_data + b'\0' * int(2 * 16000 * 0.01))
|
||||
return bool(hyp and self.key_phrase in hyp.hypstr.lower())
|
||||
|
||||
def update(self, stream: Union[BinaryIO, np.ndarray, bytes]) -> float:
|
||||
if isinstance(stream, np.ndarray):
|
||||
chunk = audio_to_buffer(stream)
|
||||
else:
|
||||
if isinstance(stream, (bytes, bytearray)):
|
||||
chunk = stream
|
||||
else:
|
||||
chunk = stream.read(self.read_size)
|
||||
if len(chunk) == 0:
|
||||
raise EOFError
|
||||
self.buffer = self.buffer[len(chunk):] + chunk
|
||||
return float(self.found_wake_word(self.buffer))
|
@ -1,67 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from precise_runner import PreciseRunner
|
||||
from precise_runner.runner import ListenerEngine
|
||||
from prettyparse import Usage
|
||||
from threading import Event
|
||||
|
||||
from precise.pocketsphinx.listener import PocketsphinxListener
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.util import activate_notify
|
||||
|
||||
|
||||
class PocketsphinxListenScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Run Pocketsphinx on microphone audio input
|
||||
|
||||
:key_phrase str
|
||||
Key phrase composed of words from dictionary
|
||||
|
||||
:dict_file str
|
||||
Filename of dictionary with word pronunciations
|
||||
|
||||
:hmm_folder str
|
||||
Folder containing hidden markov model
|
||||
|
||||
:-th --threshold str 1e-90
|
||||
Threshold for activations
|
||||
|
||||
:-c --chunk-size int 2048
|
||||
Samples between inferences
|
||||
''')
|
||||
|
||||
def run(self):
|
||||
def on_activation():
|
||||
activate_notify()
|
||||
|
||||
def on_prediction(conf):
|
||||
print('!' if conf > 0.5 else '.', end='', flush=True)
|
||||
|
||||
args = self.args
|
||||
runner = PreciseRunner(
|
||||
ListenerEngine(
|
||||
PocketsphinxListener(
|
||||
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold, args.chunk_size
|
||||
)
|
||||
), 3, on_activation=on_activation, on_prediction=on_prediction
|
||||
)
|
||||
runner.start()
|
||||
Event().wait() # Wait forever
|
||||
|
||||
|
||||
main = PocketsphinxListenScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,113 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import wave
|
||||
from prettyparse import Usage
|
||||
from subprocess import check_output, PIPE
|
||||
|
||||
from precise.pocketsphinx.listener import PocketsphinxListener
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.scripts.test import Stats
|
||||
from precise.train_data import TrainData
|
||||
|
||||
|
||||
class PocketsphinxTestScript(BaseScript):
|
||||
usage = Usage('''
|
||||
Test a dataset using Pocketsphinx
|
||||
|
||||
:key_phrase str
|
||||
Key phrase composed of words from dictionary
|
||||
|
||||
:dict_file str
|
||||
Filename of dictionary with word pronunciations
|
||||
|
||||
:hmm_folder str
|
||||
Folder containing hidden markov model
|
||||
|
||||
:-th --threshold str 1e-90
|
||||
Threshold for activations
|
||||
|
||||
:-t --use-train
|
||||
Evaluate training data instead of test data
|
||||
|
||||
:-nf --no-filenames
|
||||
Don't show the names of files that failed
|
||||
|
||||
...
|
||||
''') | TrainData.usage
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.listener = PocketsphinxListener(
|
||||
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold
|
||||
)
|
||||
|
||||
self.outputs = []
|
||||
self.targets = []
|
||||
self.filenames = []
|
||||
|
||||
def get_stats(self):
|
||||
return Stats(self.outputs, self.targets, self.filenames)
|
||||
|
||||
def run(self):
|
||||
args = self.args
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
print('Data:', data)
|
||||
|
||||
ww_files, nww_files = data.train_files if args.use_train else data.test_files
|
||||
self.run_test(ww_files, 'Wake Word', 1.0)
|
||||
self.run_test(nww_files, 'Not Wake Word', 0.0)
|
||||
stats = self.get_stats()
|
||||
if not self.args.no_filenames:
|
||||
fp_files = stats.calc_filenames(False, True, 0.5)
|
||||
fn_files = stats.calc_filenames(False, False, 0.5)
|
||||
print('=== False Positives ===')
|
||||
print('\n'.join(fp_files))
|
||||
print()
|
||||
print('=== False Negatives ===')
|
||||
print('\n'.join(fn_files))
|
||||
print()
|
||||
print(stats.counts_str(0.5))
|
||||
print()
|
||||
print(stats.summary_str(0.5))
|
||||
|
||||
def eval_file(self, filename) -> float:
|
||||
transcription = check_output(
|
||||
['pocketsphinx_continuous', '-kws_threshold', '1e-20', '-keyphrase', 'hey my craft',
|
||||
'-infile', filename], stderr=PIPE)
|
||||
return float(bool(transcription) and not transcription.isspace())
|
||||
|
||||
def run_test(self, test_files, label_name, label):
|
||||
print()
|
||||
print('===', label_name, '===')
|
||||
for test_file in test_files:
|
||||
try:
|
||||
with wave.open(test_file) as wf:
|
||||
frames = wf.readframes(wf.getnframes())
|
||||
except (OSError, EOFError):
|
||||
print('?', end='', flush=True)
|
||||
continue
|
||||
|
||||
out = int(self.listener.found_wake_word(frames))
|
||||
self.outputs.append(out)
|
||||
self.targets.append(label)
|
||||
self.filenames.append(test_file)
|
||||
print('!' if out else '.', end='', flush=True)
|
||||
print()
|
||||
|
||||
|
||||
main = PocketsphinxTestScript.run_main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
1
precise_lite/__init__.py
Normal file
1
precise_lite/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
__version__ = '0.4.0a1'
|
@ -15,8 +15,9 @@ import attr
|
||||
from os.path import isfile
|
||||
from typing import *
|
||||
|
||||
from precise.functions import load_keras, false_pos, false_neg, weighted_log_loss, set_loss_bias
|
||||
from precise.params import inject_params, pr
|
||||
from precise_lite.functions import load_keras, false_pos, false_neg, \
|
||||
weighted_log_loss, set_loss_bias
|
||||
from precise_lite.params import inject_params, pr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensorflow.keras.models import Sequential
|
||||
@ -51,7 +52,7 @@ def load_precise_model(model_name: str) -> Any:
|
||||
|
||||
def create_model(model_name: Optional[str], params: ModelParams) -> 'Sequential':
|
||||
"""
|
||||
Load or create a precise model
|
||||
Load or create a precise_lite model
|
||||
|
||||
Args:
|
||||
model_name: Name of model
|
@ -18,11 +18,11 @@ from os.path import splitext
|
||||
from typing import *
|
||||
from typing import BinaryIO
|
||||
|
||||
from precise.threshold_decoder import ThresholdDecoder
|
||||
from precise.model import load_precise_model
|
||||
from precise.params import inject_params, pr
|
||||
from precise.util import buffer_to_audio
|
||||
from precise.vectorization import vectorize_raw, add_deltas
|
||||
from precise_lite.threshold_decoder import ThresholdDecoder
|
||||
from precise_lite.model import load_precise_model
|
||||
from precise_lite.params import inject_params, pr
|
||||
from precise_lite.util import buffer_to_audio
|
||||
from precise_lite.vectorization import vectorize_raw, add_deltas
|
||||
|
||||
|
||||
class Runner(metaclass=ABCMeta):
|
@ -23,10 +23,10 @@ import shutil
|
||||
from prettyparse import Usage
|
||||
from random import random
|
||||
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import load_audio
|
||||
from precise.util import save_audio
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.train_data import TrainData
|
||||
from precise_lite.util import load_audio
|
||||
from precise_lite.util import save_audio
|
||||
|
||||
|
||||
class NoiseData:
|
@ -17,9 +17,9 @@ from math import sqrt
|
||||
from os.path import basename, splitext
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise_lite.params import inject_params, save_params
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.stats import Stats
|
||||
|
||||
|
||||
class CalcThresholdScript(BaseScript):
|
||||
@ -31,7 +31,7 @@ class CalcThresholdScript(BaseScript):
|
||||
Either Keras (.net) or TensorFlow (.pb) model to adjust
|
||||
|
||||
:input_file str
|
||||
Input stats file that was outputted from precise-graph
|
||||
Input stats file that was outputted from precise_lite-graph
|
||||
|
||||
:-k --model-key str -
|
||||
Custom model name to use from the stats.json
|
@ -22,7 +22,7 @@ from os.path import isfile
|
||||
from prettyparse import Usage
|
||||
from pyaudio import PyAudio
|
||||
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
|
||||
|
||||
def record_until(p, should_return, args):
|
||||
@ -54,7 +54,7 @@ class CollectScript(BaseScript):
|
||||
EXIT_KEY_CODE = 27
|
||||
|
||||
usage = Usage('''
|
||||
Record audio samples for use with precise
|
||||
Record audio samples for use with precise_lite
|
||||
|
||||
:-w --width int 2
|
||||
Sample width of audio
|
@ -18,7 +18,7 @@ from os.path import split, isfile
|
||||
from prettyparse import Usage
|
||||
from shutil import copyfile
|
||||
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
|
||||
class ConvertScript(BaseScript):
|
||||
usage = Usage('''
|
||||
@ -48,8 +48,8 @@ class ConvertScript(BaseScript):
|
||||
|
||||
import tensorflow as tf # Using tensorflow v2.2
|
||||
from tensorflow import keras as K
|
||||
from precise.model import load_precise_model
|
||||
from precise.functions import weighted_log_loss
|
||||
from precise_lite.model import load_precise_model
|
||||
from precise_lite.functions import weighted_log_loss
|
||||
|
||||
out_dir, filename = split(out_file)
|
||||
out_dir = out_dir or '.'
|
@ -17,9 +17,9 @@ import sys
|
||||
import os
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise import __version__
|
||||
from precise.network_runner import Listener
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise_lite import __version__
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
|
||||
|
||||
def add_audio_pipe_to_parser(parser):
|
@ -16,12 +16,12 @@ import json
|
||||
from os.path import isfile, isdir
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params
|
||||
from precise.pocketsphinx.scripts.test import PocketsphinxTestScript
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise.train_data import TrainData
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.params import inject_params
|
||||
from precise_lite.pocketsphinx.scripts.test import PocketsphinxTestScript
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.stats import Stats
|
||||
from precise_lite.train_data import TrainData
|
||||
|
||||
|
||||
class EvalScript(BaseScript):
|
@ -18,12 +18,12 @@ from os.path import basename, splitext
|
||||
from prettyparse import Usage
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params, pr
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise.threshold_decoder import ThresholdDecoder
|
||||
from precise.train_data import TrainData
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.params import inject_params, pr
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.stats import Stats
|
||||
from precise_lite.threshold_decoder import ThresholdDecoder
|
||||
from precise_lite.train_data import TrainData
|
||||
|
||||
|
||||
def get_thresholds(points=100, power=3) -> list:
|
@ -14,16 +14,16 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from os.path import join
|
||||
from precise_runner import PreciseRunner
|
||||
from precise_runner.runner import ListenerEngine
|
||||
from precise_lite_runner import PreciseRunner
|
||||
from precise_lite_runner.runner import ListenerEngine
|
||||
from prettyparse import Usage
|
||||
from random import randint
|
||||
from shutil import get_terminal_size
|
||||
from threading import Event
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.util import save_audio, buffer_to_audio, activate_notify
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.util import save_audio, buffer_to_audio, activate_notify
|
||||
|
||||
|
||||
class ListenScript(BaseScript):
|
@ -19,11 +19,11 @@ from os.path import join, basename
|
||||
from precise_runner.runner import TriggerDetector
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import pr, inject_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.util import load_audio
|
||||
from precise.vectorization import vectorize_raw
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.params import pr, inject_params
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.util import load_audio
|
||||
from precise_lite.vectorization import vectorize_raw
|
||||
|
||||
|
||||
@attr.s()
|
@ -14,11 +14,11 @@
|
||||
# limitations under the License.
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import inject_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.stats import Stats
|
||||
from precise.train_data import TrainData
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.params import inject_params
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.stats import Stats
|
||||
from precise_lite.train_data import TrainData
|
||||
|
||||
|
||||
class TestScript(BaseScript):
|
@ -18,11 +18,11 @@ from os.path import splitext, isfile
|
||||
from prettyparse import Usage
|
||||
from typing import Any, Tuple
|
||||
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.params import inject_params, save_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import calc_sample_hash
|
||||
from precise_lite.model import create_model, ModelParams
|
||||
from precise_lite.params import inject_params, save_params
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.train_data import TrainData
|
||||
from precise_lite.util import calc_sample_hash
|
||||
|
||||
|
||||
class TrainScript(BaseScript):
|
||||
@ -34,7 +34,7 @@ class TrainScript(BaseScript):
|
||||
|
||||
:-sf --samples-file str -
|
||||
Loads subset of data from the provided json file
|
||||
generated with precise-train-sampled
|
||||
generated with precise_lite-train-sampled
|
||||
|
||||
:-is --invert-samples
|
||||
Loads subset of data not inside --samples-file
|
@ -25,12 +25,12 @@ from prettyparse import Usage
|
||||
from random import random, shuffle
|
||||
from typing import *
|
||||
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.network_runner import Listener
|
||||
from precise.params import pr, save_params
|
||||
from precise.scripts.base_script import BaseScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import load_audio, glob_all, save_audio, chunk_audio
|
||||
from precise_lite.model import create_model, ModelParams
|
||||
from precise_lite.network_runner import Listener
|
||||
from precise_lite.params import pr, save_params
|
||||
from precise_lite.scripts.base_script import BaseScript
|
||||
from precise_lite.train_data import TrainData
|
||||
from precise_lite.util import load_audio, glob_all, save_audio, chunk_audio
|
||||
|
||||
|
||||
class TrainGeneratedScript(BaseScript):
|
@ -19,12 +19,12 @@ from prettyparse import Usage
|
||||
from random import random
|
||||
from typing import *
|
||||
|
||||
from precise.model import create_model, ModelParams
|
||||
from precise.network_runner import Listener, KerasRunner
|
||||
from precise.params import pr
|
||||
from precise.scripts.train import TrainScript
|
||||
from precise.train_data import TrainData
|
||||
from precise.util import load_audio, save_audio, glob_all, chunk_audio
|
||||
from precise_lite.model import create_model, ModelParams
|
||||
from precise_lite.network_runner import Listener, KerasRunner
|
||||
from precise_lite.params import pr
|
||||
from precise_lite.scripts.train import TrainScript
|
||||
from precise_lite.train_data import TrainData
|
||||
from precise_lite.util import load_audio, save_audio, glob_all, chunk_audio
|
||||
|
||||
|
||||
def load_trained_fns(model_name: str) -> list:
|
@ -22,9 +22,9 @@ from prettyparse import Usage
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
|
||||
from precise.model import ModelParams, create_model
|
||||
from precise.scripts.train import TrainScript
|
||||
from precise.train_data import TrainData
|
||||
from precise_lite.model import ModelParams, create_model
|
||||
from precise_lite.scripts.train import TrainScript
|
||||
from precise_lite.train_data import TrainData
|
||||
|
||||
|
||||
class TrainOptimizeScript(TrainScript):
|
@ -17,8 +17,8 @@ from itertools import islice
|
||||
from fitipy import Fitipy
|
||||
from prettyparse import Usage
|
||||
|
||||
from precise.scripts.train import TrainScript
|
||||
from precise.util import calc_sample_hash
|
||||
from precise_lite.scripts.train import TrainScript
|
||||
from precise_lite.util import calc_sample_hash
|
||||
|
||||
|
||||
class TrainSampledScript(TrainScript):
|
@ -14,7 +14,7 @@
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
from precise.functions import asigmoid, sigmoid, pdf
|
||||
from precise_lite.functions import asigmoid, sigmoid, pdf
|
||||
|
||||
|
||||
class ThresholdDecoder:
|
@ -20,8 +20,8 @@ from prettyparse import Usage
|
||||
from pyache import Pyache
|
||||
from typing import *
|
||||
|
||||
from precise.util import find_wavs, load_audio
|
||||
from precise.vectorization import vectorize_delta, vectorize
|
||||
from precise_lite.util import find_wavs, load_audio
|
||||
from precise_lite.vectorization import vectorize_delta, vectorize
|
||||
|
||||
|
||||
class TrainData:
|
||||
@ -140,7 +140,7 @@ class TrainData:
|
||||
"""Generate data with inhibitory inputs created from wake word samples"""
|
||||
|
||||
def loader(kws: list, nkws: list):
|
||||
from precise.params import pr
|
||||
from precise_lite.params import pr
|
||||
inputs = np.empty((0, pr.n_features, pr.feature_size))
|
||||
outputs = np.zeros((len(kws), 1))
|
||||
for f in kws:
|
||||
@ -182,7 +182,7 @@ class TrainData:
|
||||
|
||||
@staticmethod
|
||||
def __load_files(kw_files: list, nkw_files: list, vectorizer: Callable = None, shuffle=True) -> tuple:
|
||||
from precise.params import pr
|
||||
from precise_lite.params import pr
|
||||
|
||||
input_parts = []
|
||||
output_parts = []
|
||||
@ -213,7 +213,7 @@ class TrainData:
|
||||
print('Loading not-wake-word...')
|
||||
add(nkw_files, 0.0)
|
||||
|
||||
from precise.params import pr
|
||||
from precise_lite.params import pr
|
||||
inputs = np.concatenate(input_parts) if input_parts else np.empty((0, pr.n_features, pr.feature_size))
|
||||
outputs = np.concatenate(output_parts) if output_parts else np.empty((0, 1))
|
||||
|
@ -16,7 +16,7 @@ import numpy as np
|
||||
from os.path import join, dirname, abspath
|
||||
from typing import *
|
||||
|
||||
from precise.params import pr
|
||||
from precise_lite.params import pr
|
||||
|
||||
|
||||
class InvalidAudio(ValueError):
|
@ -16,8 +16,8 @@ import numpy as np
|
||||
import os
|
||||
from typing import *
|
||||
|
||||
from precise.params import pr, Vectorizer
|
||||
from precise.util import load_audio, InvalidAudio
|
||||
from precise_lite.params import pr, Vectorizer
|
||||
from precise_lite.util import load_audio, InvalidAudio
|
||||
from sonopy import mfcc_spec, mel_spec
|
||||
|
||||
inhibit_t = 0.4
|
103
publish.sh
103
publish.sh
@ -1,103 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Usage: upload_file FILE REMOTE_PATH
|
||||
upload_s3() {
|
||||
file="$1"
|
||||
remote_url="s3://$2"
|
||||
eval cfg_file="~/.s3cfg.mycroft-artifact-writer"
|
||||
[ -f "$cfg_file" ] && s3cmd put $1 $remote_url --acl-public -c ~/.s3cfg.mycroft-artifact-writer || echo "Could not find $cfg_file. Skipping upload."
|
||||
}
|
||||
|
||||
# Usage: upload_git FILE GIT_FOLDER
|
||||
upload_git() {
|
||||
[ -d 'precise-data' ] || git clone git@github.com:MycroftAI/precise-data.git
|
||||
cd precise-data
|
||||
git fetch
|
||||
git checkout origin/dist
|
||||
mv ../$1 $2
|
||||
git add $2
|
||||
git commit --amend --no-edit
|
||||
git push --force origin HEAD:dist
|
||||
cd ..
|
||||
}
|
||||
|
||||
# Usage: find_type stable|unstable
|
||||
find_type() {
|
||||
[ "$1" = "stable" ] && echo "release" || echo "daily"
|
||||
}
|
||||
|
||||
# Usage: find_version stable|unstable
|
||||
find_version() {
|
||||
[ "$1" = "stable" ] && git describe --abbrev=0 || date +%s
|
||||
}
|
||||
|
||||
find_arch() {
|
||||
python3 -c 'import platform; print(platform.machine())'
|
||||
}
|
||||
|
||||
# Usage: show_usage $0
|
||||
show_usage() {
|
||||
echo "Usage: $1 stable|unstable [git|s3]"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Usage: parse_args "$@"
|
||||
parse_args() {
|
||||
build_type="error"
|
||||
upload_type="s3"
|
||||
|
||||
while [ $# -gt 0 ]; do
|
||||
case "$1" in
|
||||
stable|unstable)
|
||||
build_type="$1";;
|
||||
git|s3)
|
||||
upload_type="$1";;
|
||||
*)
|
||||
show_usage
|
||||
esac
|
||||
shift
|
||||
done
|
||||
[ "$build_type" != "error" ] || show_usage
|
||||
}
|
||||
|
||||
set -e
|
||||
|
||||
parse_args "$@"
|
||||
|
||||
type="$(find_type $build_type)"
|
||||
version="$(find_version $build_type)"
|
||||
arch="$(find_arch)"
|
||||
|
||||
.venv/bin/pip3 install pyinstaller
|
||||
rm -rf dist/
|
||||
echo "Building executable..."
|
||||
.venv/bin/pyinstaller -y precise.engine.spec
|
||||
|
||||
out_file=dist/precise-engine.tar.gz
|
||||
cd dist
|
||||
tar -czvf "precise-engine.tar.gz" precise-engine
|
||||
cd -
|
||||
|
||||
echo $version > latest
|
||||
|
||||
if [ "$upload_type" = "git" ]; then
|
||||
upload_git "$out_file" $arch/
|
||||
else
|
||||
upload_s3 "$out_file" bootstrap.mycroft.ai/artifacts/static/$type/$arch/$version/
|
||||
upload_s3 "$out_file" bootstrap.mycroft.ai/artifacts/static/$type/$arch/ # Replace latest version
|
||||
upload_s3 latest bootstrap.mycroft.ai/artifacts/static/$type/$arch/
|
||||
fi
|
||||
|
@ -15,12 +15,12 @@ kiwisolver==1.0.1
|
||||
Markdown==3.1
|
||||
matplotlib==3.0.3
|
||||
mock==2.0.0
|
||||
-e git+git@github.com:MycroftAI/mycroft-precise@37ef1ab91eeca81fd889bce2967775b2f6918d97#egg=mycroft_precise
|
||||
|
||||
numpy==1.16.2
|
||||
pbr==5.1.3
|
||||
pocketsphinx==0.1.15
|
||||
portalocker==1.4.0
|
||||
-e git+git@github.com:MycroftAI/mycroft-precise@37ef1ab91eeca81fd889bce2967775b2f6918d97#egg=precise_runner&subdirectory=runner
|
||||
|
||||
prettyparse==0.1.4
|
||||
protobuf==3.7.1
|
||||
PyAudio==0.2.11
|
||||
|
@ -13,8 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from argparse import ArgumentParser
|
||||
from precise.util import activate_notify
|
||||
from precise_runner import PreciseRunner, PreciseEngine
|
||||
from precise_lite.util import activate_notify
|
||||
from precise_lite_runner import PreciseRunner, PreciseEngine
|
||||
from threading import Event
|
||||
|
||||
|
||||
|
@ -13,10 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from setuptools import setup, find_packages
|
||||
from precise_runner import __version__
|
||||
from precise_lite_runner import __version__
|
||||
|
||||
setup(
|
||||
name='precise-runner',
|
||||
name='precise_lite_runner',
|
||||
version=__version__,
|
||||
packages=find_packages(),
|
||||
install_requires=[
|
||||
@ -25,10 +25,9 @@ setup(
|
||||
|
||||
author='Matthew Scholefield',
|
||||
author_email='matthew.scholefield@mycroft.ai',
|
||||
description='Wrapper to use Mycroft Precise Wake Word Listener',
|
||||
description='Wrapper to use OVOS Precise-lite Wake Word Listener',
|
||||
keywords='wakeword keyword wake word listener sound',
|
||||
url='http://github.com/MycroftAI/mycroft-precise',
|
||||
|
||||
url='https://github.com/OpenVoiceOS/precise-lite',
|
||||
zip_safe=True,
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
|
50
setup.py
50
setup.py
@ -14,18 +14,16 @@
|
||||
# limitations under the License.
|
||||
from setuptools import setup
|
||||
|
||||
from precise import __version__
|
||||
from precise_lite import __version__
|
||||
|
||||
setup(
|
||||
name='mycroft-precise',
|
||||
name='precise_lite',
|
||||
version=__version__,
|
||||
license='Apache-2.0',
|
||||
author='Matthew Scholefield',
|
||||
author_email='matthew.scholefield@mycroft.ai',
|
||||
description='Mycroft Precise Wake Word Listener',
|
||||
long_description='View more info at `the GitHub page '
|
||||
'<https://github.com/mycroftai/mycroft-precise#mycroft-precise>`_',
|
||||
url='http://github.com/MycroftAI/mycroft-precise',
|
||||
description='Mycroft Precise Wake Word Listener, Lite version (OpenVoiceOS)',
|
||||
url='https://github.com/OpenVoiceOS/precise-lite',
|
||||
keywords='wakeword keyword wake word listener sound',
|
||||
classifiers=[
|
||||
'Development Status :: 3 - Alpha',
|
||||
@ -43,30 +41,26 @@ setup(
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
packages=[
|
||||
'precise',
|
||||
'precise.scripts',
|
||||
'precise.pocketsphinx',
|
||||
'precise.pocketsphinx.scripts'
|
||||
'precise_lite',
|
||||
'precise_lite.scripts'
|
||||
],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'precise-add-noise=precise.scripts.add_noise:main',
|
||||
'precise-collect=precise.scripts.collect:main',
|
||||
'precise-convert=precise.scripts.convert:main',
|
||||
'precise-eval=precise.scripts.eval:main',
|
||||
'precise-listen=precise.scripts.listen:main',
|
||||
'precise-listen-pocketsphinx=precise.pocketsphinx.scripts.listen:main',
|
||||
'precise-engine=precise.scripts.engine:main',
|
||||
'precise-simulate=precise.scripts.simulate:main',
|
||||
'precise-test=precise.scripts.test:main',
|
||||
'precise-graph=precise.scripts.graph:main',
|
||||
'precise-test-pocketsphinx=precise.pocketsphinx.scripts.test:main',
|
||||
'precise-train=precise.scripts.train:main',
|
||||
'precise-train-optimize=precise.scripts.train_optimize:main',
|
||||
'precise-train-sampled=precise.scripts.train_sampled:main',
|
||||
'precise-train-incremental=precise.scripts.train_incremental:main',
|
||||
'precise-train-generated=precise.scripts.train_generated:main',
|
||||
'precise-calc-threshold=precise.scripts.calc_threshold:main',
|
||||
'precise-lite-add-noise=precise_lite.scripts.add_noise:main',
|
||||
'precise-lite-collect=precise_lite.scripts.collect:main',
|
||||
'precise-lite-convert=precise_lite.scripts.convert:main',
|
||||
'precise-lite-eval=precise_lite.scripts.eval:main',
|
||||
'precise-lite-listen=precise_lite.scripts.listen:main',
|
||||
'precise-lite-engine=precise_lite.scripts.engine:main',
|
||||
'precise-lite-simulate=precise_lite.scripts.simulate:main',
|
||||
'precise-lite-test=precise_lite.scripts.test:main',
|
||||
'precise-lite-graph=precise_lite.scripts.graph:main',
|
||||
'precise-lite-train=precise_lite.scripts.train:main',
|
||||
'precise-lite-train-optimize=precise_lite.scripts.train_optimize:main',
|
||||
'precise-lite-train-sampled=precise_lite.scripts.train_sampled:main',
|
||||
'precise-lite-train-incremental=precise_lite.scripts.train_incremental:main',
|
||||
'precise-lite-train-generated=precise_lite.scripts.train_generated:main',
|
||||
'precise-lite-calc-threshold=precise_lite.scripts.calc_threshold:main',
|
||||
]
|
||||
},
|
||||
install_requires=[
|
||||
@ -78,7 +72,7 @@ setup(
|
||||
'wavio',
|
||||
'typing',
|
||||
'prettyparse>=1.1.0',
|
||||
'precise-runner',
|
||||
'precise_lite_runner',
|
||||
'attrs',
|
||||
'fitipy<1.0',
|
||||
'speechpy-fast',
|
||||
|
@ -1,32 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from precise.scripts.train import TrainScript
|
||||
from test.scripts.test_train import DummyTrainFolder
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def train_folder():
|
||||
folder = DummyTrainFolder(10)
|
||||
try:
|
||||
yield folder
|
||||
finally:
|
||||
folder.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def train_script(train_folder):
|
||||
return TrainScript.create(model=train_folder.model, folder=train_folder.root, epochs=1)
|
@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import atexit
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from os import makedirs
|
||||
from os.path import isdir, join
|
||||
from shutil import rmtree
|
||||
from tempfile import mkdtemp
|
||||
|
||||
from precise.params import pr
|
||||
from precise.util import save_audio
|
||||
|
||||
|
||||
class DummyAudioFolder:
|
||||
def __init__(self, count=10):
|
||||
self.count = count
|
||||
self.root = mkdtemp()
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def rand(self, min, max):
|
||||
return min + (max - min) * np.random.random() * pr.buffer_t
|
||||
|
||||
def generate_samples(self, folder, name, value, duration):
|
||||
for i in range(self.count):
|
||||
save_audio(join(folder, name.format(i)), np.array([value] * int(duration * pr.sample_rate)))
|
||||
|
||||
def subdir(self, *parts):
|
||||
folder = self.path(*parts)
|
||||
if not isdir(folder):
|
||||
makedirs(folder)
|
||||
return folder
|
||||
|
||||
def path(self, *path):
|
||||
return join(self.root, *path)
|
||||
|
||||
def count_files(self, folder):
|
||||
return sum([len(files) for r, d, files in os.walk(folder)])
|
||||
|
||||
def cleanup(self):
|
||||
if isdir(self.root):
|
||||
rmtree(self.root)
|
@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from precise.scripts.add_noise import AddNoiseScript
|
||||
|
||||
from test.scripts.dummy_audio_folder import DummyAudioFolder
|
||||
|
||||
|
||||
class DummyNoiseFolder(DummyAudioFolder):
|
||||
def __init__(self, count=10):
|
||||
super().__init__(count)
|
||||
self.source = self.subdir('source')
|
||||
self.noise = self.subdir('noise')
|
||||
self.output = self.subdir('output')
|
||||
|
||||
self.generate_samples(self.subdir('source', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2))
|
||||
self.generate_samples(self.subdir('source', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2))
|
||||
self.generate_samples(self.noise, 'noise-{}.wav', 0.5, self.rand(10, 20))
|
||||
|
||||
|
||||
class TestAddNoise:
|
||||
def get_base_data(self, count):
|
||||
folders = DummyNoiseFolder(count)
|
||||
base_args = dict(
|
||||
folder=folders.source, noise_folder=folders.noise,
|
||||
output_folder=folders.output
|
||||
)
|
||||
return folders, base_args
|
||||
|
||||
def test_run_basic(self):
|
||||
folders, base_args = self.get_base_data(10)
|
||||
script = AddNoiseScript.create(inflation_factor=1, **base_args)
|
||||
script.run()
|
||||
assert folders.count_files(folders.output) == 20
|
||||
|
||||
def test_run_basic_2(self):
|
||||
folders, base_args = self.get_base_data(10)
|
||||
script = AddNoiseScript.create(inflation_factor=2, **base_args)
|
||||
script.run()
|
||||
assert folders.count_files(folders.output) == 40
|
@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
from os.path import isfile
|
||||
|
||||
from precise.scripts.calc_threshold import CalcThresholdScript
|
||||
from precise.scripts.eval import EvalScript
|
||||
from precise.scripts.graph import GraphScript
|
||||
|
||||
|
||||
def read_content(filename):
|
||||
with open(filename) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def test_combined(train_folder, train_script):
|
||||
train_script.run()
|
||||
params_file = train_folder.model + '.params'
|
||||
assert isfile(train_folder.model)
|
||||
assert isfile(params_file)
|
||||
|
||||
EvalScript.create(folder=train_folder.root, models=[train_folder.model]).run()
|
||||
|
||||
out_file = train_folder.path('outputs.npz')
|
||||
graph_script = GraphScript.create(folder=train_folder.root, models=[train_folder.model], output_file=out_file)
|
||||
graph_script.run()
|
||||
assert isfile(out_file)
|
||||
|
||||
params_before = read_content(params_file)
|
||||
CalcThresholdScript.create(folder=train_folder.root, model=train_folder.model, input_file=out_file).run()
|
||||
assert params_before != read_content(params_file)
|
@ -1,24 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from os.path import isfile
|
||||
|
||||
from precise.scripts.convert import ConvertScript
|
||||
|
||||
|
||||
def test_convert(train_folder, train_script):
|
||||
train_script.run()
|
||||
|
||||
ConvertScript.create(model=train_folder.model, out=train_folder.model + '.pb').run()
|
||||
assert isfile(train_folder.model + '.pb')
|
@ -1,49 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import glob
|
||||
import re
|
||||
from os.path import join
|
||||
|
||||
from precise.scripts.engine import EngineScript
|
||||
from runner.precise_runner import ReadWriteStream
|
||||
|
||||
|
||||
class FakeStdin:
|
||||
def __init__(self, data: bytes):
|
||||
self.buffer = ReadWriteStream(data)
|
||||
|
||||
def isatty(self):
|
||||
return False
|
||||
|
||||
|
||||
class FakeStdout:
|
||||
def __init__(self):
|
||||
self.buffer = ReadWriteStream()
|
||||
|
||||
|
||||
def test_engine(train_folder, train_script):
|
||||
train_script.run()
|
||||
with open(glob.glob(join(train_folder.root, 'wake-word', '*.wav'))[0], 'rb') as f:
|
||||
data = f.read()
|
||||
try:
|
||||
sys.stdin = FakeStdin(data)
|
||||
sys.stdout = FakeStdout()
|
||||
EngineScript.create(model_name=train_folder.model).run()
|
||||
assert re.match(rb'[01]\.[0-9]+', sys.stdout.buffer.buffer)
|
||||
finally:
|
||||
sys.stdin = sys.__stdin__
|
||||
sys.stdout = sys.__stdout__
|
@ -1,37 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from os.path import isfile
|
||||
|
||||
from precise.params import pr
|
||||
from precise.scripts.train import TrainScript
|
||||
from test.scripts.dummy_audio_folder import DummyAudioFolder
|
||||
|
||||
|
||||
class DummyTrainFolder(DummyAudioFolder):
|
||||
def __init__(self, count=10):
|
||||
super().__init__(count)
|
||||
self.generate_samples(self.subdir('wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.generate_samples(self.subdir('not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.generate_samples(self.subdir('test', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.generate_samples(self.subdir('test', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.model = self.path('model.net')
|
||||
|
||||
|
||||
class TestTrain:
|
||||
def test_run_basic(self):
|
||||
folders = DummyTrainFolder(10)
|
||||
script = TrainScript.create(model=folders.model, folder=folders.root)
|
||||
script.run()
|
||||
assert isfile(folders.model)
|
Loading…
Reference in New Issue
Block a user