From bb34d65cd7e435065967f5089e2b7f12bf619aa6 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Tue, 26 May 2020 13:15:12 -0700 Subject: [PATCH] Speed up creation of visualizer html page for TensorFlow Lite. Use the NumPy functionality of the object based flatbuffer API. This speeds up a model that took 15 minutes to visualize. PiperOrigin-RevId: 313255207 Change-Id: Ic9d43cbd97c6d5026d903ee947a0a56a0732f150 --- tensorflow/lite/tools/BUILD | 5 ++++- tensorflow/lite/tools/visualize.py | 34 ++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index a96c1c3ede3..6ae5c1dda18 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -17,7 +17,10 @@ py_binary( srcs = ["visualize.py"], python_version = "PY3", srcs_version = "PY2AND3", - deps = ["//tensorflow/lite/python:schema_py"], + deps = [ + "//tensorflow/lite/python:schema_py", + "//third_party/py/numpy", + ], ) py_test( diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py index 1f89f9c5448..3d22d1bb05b 100644 --- a/tensorflow/lite/tools/visualize.py +++ b/tensorflow/lite/tools/visualize.py @@ -28,6 +28,7 @@ import json import os import re import sys +import numpy as np from tensorflow.lite.python import schema_py_generated as schema_fb @@ -377,23 +378,34 @@ def CamelCaseToSnakeCase(camel_case_input): return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def FlatbufferToDict(fb): - """Converts a hierarchy of FB objects into a nested dict.""" - if hasattr(fb, "__dict__"): +def FlatbufferToDict(fb, preserve_as_numpy): + """Converts a hierarchy of FB objects into a nested dict. + + We avoid transforming big parts of the flat buffer into python arrays. This + speeds conversion from ten minutes to a few seconds on big graphs. + + Args: + fb: a flat buffer structure. (i.e. ModelT) + preserve_as_numpy: true if all downstream np.arrays should be preserved. + false if all downstream np.array should become python arrays + Returns: + A dictionary representing the flatbuffer rather than a flatbuffer object. + """ + if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str): + return fb + elif hasattr(fb, "__dict__"): result = {} for attribute_name in dir(fb): attribute = fb.__getattribute__(attribute_name) if not callable(attribute) and attribute_name[0] != "_": snake_name = CamelCaseToSnakeCase(attribute_name) - result[snake_name] = FlatbufferToDict(attribute) + preserve = True if attribute_name == "buffers" else preserve_as_numpy + result[snake_name] = FlatbufferToDict(attribute, preserve) return result - elif isinstance(fb, str): - return fb + elif isinstance(fb, np.ndarray): + return fb if preserve_as_numpy else fb.tolist() elif hasattr(fb, "__len__"): - result = [] - for entry in fb: - result.append(FlatbufferToDict(entry)) - return result + return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb] else: return fb @@ -401,7 +413,7 @@ def FlatbufferToDict(fb): def CreateDictFromFlatbuffer(buffer_data): model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0) model = schema_fb.ModelT.InitFromObj(model_obj) - return FlatbufferToDict(model) + return FlatbufferToDict(model, preserve_as_numpy=False) def CreateHtmlFile(tflite_input, html_output):