Speed up creation of visualizer html page for TensorFlow Lite.

Use the NumPy functionality of the object based flatbuffer API.
This speeds up a model that took 15 minutes to visualize.

PiperOrigin-RevId: 313255207
Change-Id: Ic9d43cbd97c6d5026d903ee947a0a56a0732f150
This commit is contained in:
Andrew Selle 2020-05-26 13:15:12 -07:00 committed by TensorFlower Gardener
parent 15bf2a7e76
commit bb34d65cd7
2 changed files with 27 additions and 12 deletions

View File

@ -17,7 +17,10 @@ py_binary(
srcs = ["visualize.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = ["//tensorflow/lite/python:schema_py"],
deps = [
"//tensorflow/lite/python:schema_py",
"//third_party/py/numpy",
],
)
py_test(

View File

@ -28,6 +28,7 @@ import json
import os
import re
import sys
import numpy as np
from tensorflow.lite.python import schema_py_generated as schema_fb
@ -377,23 +378,34 @@ def CamelCaseToSnakeCase(camel_case_input):
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def FlatbufferToDict(fb):
"""Converts a hierarchy of FB objects into a nested dict."""
if hasattr(fb, "__dict__"):
def FlatbufferToDict(fb, preserve_as_numpy):
"""Converts a hierarchy of FB objects into a nested dict.
We avoid transforming big parts of the flat buffer into python arrays. This
speeds conversion from ten minutes to a few seconds on big graphs.
Args:
fb: a flat buffer structure. (i.e. ModelT)
preserve_as_numpy: true if all downstream np.arrays should be preserved.
false if all downstream np.array should become python arrays
Returns:
A dictionary representing the flatbuffer rather than a flatbuffer object.
"""
if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str):
return fb
elif hasattr(fb, "__dict__"):
result = {}
for attribute_name in dir(fb):
attribute = fb.__getattribute__(attribute_name)
if not callable(attribute) and attribute_name[0] != "_":
snake_name = CamelCaseToSnakeCase(attribute_name)
result[snake_name] = FlatbufferToDict(attribute)
preserve = True if attribute_name == "buffers" else preserve_as_numpy
result[snake_name] = FlatbufferToDict(attribute, preserve)
return result
elif isinstance(fb, str):
return fb
elif isinstance(fb, np.ndarray):
return fb if preserve_as_numpy else fb.tolist()
elif hasattr(fb, "__len__"):
result = []
for entry in fb:
result.append(FlatbufferToDict(entry))
return result
return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
else:
return fb
@ -401,7 +413,7 @@ def FlatbufferToDict(fb):
def CreateDictFromFlatbuffer(buffer_data):
model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
model = schema_fb.ModelT.InitFromObj(model_obj)
return FlatbufferToDict(model)
return FlatbufferToDict(model, preserve_as_numpy=False)
def CreateHtmlFile(tflite_input, html_output):