Speed up creation of visualizer html page for TensorFlow Lite.
Use the NumPy functionality of the object based flatbuffer API. This speeds up a model that took 15 minutes to visualize. PiperOrigin-RevId: 313255207 Change-Id: Ic9d43cbd97c6d5026d903ee947a0a56a0732f150
This commit is contained in:
parent
15bf2a7e76
commit
bb34d65cd7
@ -17,7 +17,10 @@ py_binary(
|
|||||||
srcs = ["visualize.py"],
|
srcs = ["visualize.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = ["//tensorflow/lite/python:schema_py"],
|
deps = [
|
||||||
|
"//tensorflow/lite/python:schema_py",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -28,6 +28,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.lite.python import schema_py_generated as schema_fb
|
from tensorflow.lite.python import schema_py_generated as schema_fb
|
||||||
|
|
||||||
@ -377,23 +378,34 @@ def CamelCaseToSnakeCase(camel_case_input):
|
|||||||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
||||||
|
|
||||||
|
|
||||||
def FlatbufferToDict(fb):
|
def FlatbufferToDict(fb, preserve_as_numpy):
|
||||||
"""Converts a hierarchy of FB objects into a nested dict."""
|
"""Converts a hierarchy of FB objects into a nested dict.
|
||||||
if hasattr(fb, "__dict__"):
|
|
||||||
|
We avoid transforming big parts of the flat buffer into python arrays. This
|
||||||
|
speeds conversion from ten minutes to a few seconds on big graphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fb: a flat buffer structure. (i.e. ModelT)
|
||||||
|
preserve_as_numpy: true if all downstream np.arrays should be preserved.
|
||||||
|
false if all downstream np.array should become python arrays
|
||||||
|
Returns:
|
||||||
|
A dictionary representing the flatbuffer rather than a flatbuffer object.
|
||||||
|
"""
|
||||||
|
if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str):
|
||||||
|
return fb
|
||||||
|
elif hasattr(fb, "__dict__"):
|
||||||
result = {}
|
result = {}
|
||||||
for attribute_name in dir(fb):
|
for attribute_name in dir(fb):
|
||||||
attribute = fb.__getattribute__(attribute_name)
|
attribute = fb.__getattribute__(attribute_name)
|
||||||
if not callable(attribute) and attribute_name[0] != "_":
|
if not callable(attribute) and attribute_name[0] != "_":
|
||||||
snake_name = CamelCaseToSnakeCase(attribute_name)
|
snake_name = CamelCaseToSnakeCase(attribute_name)
|
||||||
result[snake_name] = FlatbufferToDict(attribute)
|
preserve = True if attribute_name == "buffers" else preserve_as_numpy
|
||||||
|
result[snake_name] = FlatbufferToDict(attribute, preserve)
|
||||||
return result
|
return result
|
||||||
elif isinstance(fb, str):
|
elif isinstance(fb, np.ndarray):
|
||||||
return fb
|
return fb if preserve_as_numpy else fb.tolist()
|
||||||
elif hasattr(fb, "__len__"):
|
elif hasattr(fb, "__len__"):
|
||||||
result = []
|
return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
|
||||||
for entry in fb:
|
|
||||||
result.append(FlatbufferToDict(entry))
|
|
||||||
return result
|
|
||||||
else:
|
else:
|
||||||
return fb
|
return fb
|
||||||
|
|
||||||
@ -401,7 +413,7 @@ def FlatbufferToDict(fb):
|
|||||||
def CreateDictFromFlatbuffer(buffer_data):
|
def CreateDictFromFlatbuffer(buffer_data):
|
||||||
model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
|
model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
|
||||||
model = schema_fb.ModelT.InitFromObj(model_obj)
|
model = schema_fb.ModelT.InitFromObj(model_obj)
|
||||||
return FlatbufferToDict(model)
|
return FlatbufferToDict(model, preserve_as_numpy=False)
|
||||||
|
|
||||||
|
|
||||||
def CreateHtmlFile(tflite_input, html_output):
|
def CreateHtmlFile(tflite_input, html_output):
|
||||||
|
Loading…
Reference in New Issue
Block a user