Speed up creation of visualizer html page for TensorFlow Lite.
Use the NumPy functionality of the object based flatbuffer API. This speeds up a model that took 15 minutes to visualize. PiperOrigin-RevId: 313255207 Change-Id: Ic9d43cbd97c6d5026d903ee947a0a56a0732f150
This commit is contained in:
parent
15bf2a7e76
commit
bb34d65cd7
@ -17,7 +17,10 @@ py_binary(
|
||||
srcs = ["visualize.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow/lite/python:schema_py"],
|
||||
deps = [
|
||||
"//tensorflow/lite/python:schema_py",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
@ -28,6 +28,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||
|
||||
@ -377,23 +378,34 @@ def CamelCaseToSnakeCase(camel_case_input):
|
||||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
||||
|
||||
|
||||
def FlatbufferToDict(fb):
|
||||
"""Converts a hierarchy of FB objects into a nested dict."""
|
||||
if hasattr(fb, "__dict__"):
|
||||
def FlatbufferToDict(fb, preserve_as_numpy):
|
||||
"""Converts a hierarchy of FB objects into a nested dict.
|
||||
|
||||
We avoid transforming big parts of the flat buffer into python arrays. This
|
||||
speeds conversion from ten minutes to a few seconds on big graphs.
|
||||
|
||||
Args:
|
||||
fb: a flat buffer structure. (i.e. ModelT)
|
||||
preserve_as_numpy: true if all downstream np.arrays should be preserved.
|
||||
false if all downstream np.array should become python arrays
|
||||
Returns:
|
||||
A dictionary representing the flatbuffer rather than a flatbuffer object.
|
||||
"""
|
||||
if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str):
|
||||
return fb
|
||||
elif hasattr(fb, "__dict__"):
|
||||
result = {}
|
||||
for attribute_name in dir(fb):
|
||||
attribute = fb.__getattribute__(attribute_name)
|
||||
if not callable(attribute) and attribute_name[0] != "_":
|
||||
snake_name = CamelCaseToSnakeCase(attribute_name)
|
||||
result[snake_name] = FlatbufferToDict(attribute)
|
||||
preserve = True if attribute_name == "buffers" else preserve_as_numpy
|
||||
result[snake_name] = FlatbufferToDict(attribute, preserve)
|
||||
return result
|
||||
elif isinstance(fb, str):
|
||||
return fb
|
||||
elif isinstance(fb, np.ndarray):
|
||||
return fb if preserve_as_numpy else fb.tolist()
|
||||
elif hasattr(fb, "__len__"):
|
||||
result = []
|
||||
for entry in fb:
|
||||
result.append(FlatbufferToDict(entry))
|
||||
return result
|
||||
return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
|
||||
else:
|
||||
return fb
|
||||
|
||||
@ -401,7 +413,7 @@ def FlatbufferToDict(fb):
|
||||
def CreateDictFromFlatbuffer(buffer_data):
|
||||
model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
|
||||
model = schema_fb.ModelT.InitFromObj(model_obj)
|
||||
return FlatbufferToDict(model)
|
||||
return FlatbufferToDict(model, preserve_as_numpy=False)
|
||||
|
||||
|
||||
def CreateHtmlFile(tflite_input, html_output):
|
||||
|
Loading…
Reference in New Issue
Block a user