precise-lite-amd64aarch64/precise_lite/scripts/convert.py

77 lines
2.6 KiB
Python

#!/usr/bin/env python3
# Attribution: This script was adapted from https://github.com/amir-abdi/keras_to_tensorflow
# Copyright 2019 Mycroft AI Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import split, isfile
from prettyparse import Usage
from shutil import copyfile
from precise_lite.scripts.base_script import BaseScript
class ConvertScript(BaseScript):
    """Command-line script that converts a Keras wake word model to TFLite."""

    # The usage text is kept flush-left so the string handed to Usage is
    # exactly the same as before.
    usage = Usage('''
Convert wake word model from Keras to TensorFlow
:model str
Input Keras model (.net)
:-o --out str {model}.tflite
Custom output TensorFlow Lite filename
''')

    def run(self) -> None:
        """Resolve the output filename from the parsed arguments and convert.

        The ``{model}`` placeholder in the ``--out`` value is filled with the
        input model path minus a trailing ``.net`` extension.
        """
        args = self.args
        model_name = args.model
        # Strip only a trailing '.net'. str.replace() would also remove any
        # '.net' occurring earlier in the path (e.g. in a directory name such
        # as 'my.network/'), silently producing a wrong output location.
        if model_name.endswith('.net'):
            model_name = model_name[:-len('.net')]
        self.convert(args.model, args.out.format(model=model_name))

    def convert(self, model_path: str, out_file: str) -> None:
        """
        Converts an HDF5 file from Keras to a .tflite file.

        The directory part of ``out_file`` is created if it does not exist.

        Args:
            model_path: location of Keras model
            out_file: location to write TFLite model
        """
        print('Converting', model_path, 'to', out_file, '...')

        # Imported inside the function rather than at module level,
        # presumably so that merely importing this script does not pull in
        # TensorFlow.
        import tensorflow as tf  # Using tensorflow v2.2
        from tensorflow import keras as K
        # NOTE(review): load_precise_model is not referenced below; the
        # import is kept in case importing precise_lite.model has side
        # effects -- confirm and drop if it does not.
        from precise_lite.model import load_precise_model
        from precise_lite.functions import weighted_log_loss

        # A bare filename has an empty directory part; fall back to the
        # current directory so makedirs() gets a valid path.
        out_dir = split(out_file)[0] or '.'
        os.makedirs(out_dir, exist_ok=True)

        # Load custom loss function with model
        model = K.models.load_model(
            model_path,
            custom_objects={'weighted_log_loss': weighted_log_loss},
        )
        model.summary()

        # Support for freezing Keras models to .pb has been removed in TF 2.0.
        # Converting instead to TFLite model
        print('Starting TFLite conversion.')
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        # NOTE(review): the TF2 converter documents
        # `converter.target_spec.supported_ops` for selecting op sets;
        # `target_ops` looks like the TF1-era name and may be ignored here.
        # Left unchanged so the generated model is not altered -- confirm
        # before relying on SELECT_TF_OPS being enabled.
        converter.target_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        tflite_model = converter.convert()

        # Use a context manager so the file is flushed and closed
        # deterministically, even if the write raises.
        with open(out_file, 'wb') as tflite_file:
            tflite_file.write(tflite_model)
        print('Wrote to ' + out_file)
# Module-level entry point: `main` is the `run_main` callable exposed by
# ConvertScript (not defined in the class body here, so presumably inherited
# from BaseScript).
main = ConvertScript.run_main

if __name__ == "__main__":
    # Executed directly as a script rather than imported as a module.
    main()