From 6b77079914c8d9888877d50d48beb62463173374 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Tue, 5 Nov 2019 17:28:13 -0800 Subject: [PATCH] *Add flag 'conversion_summary_dir' to tflite_converter. When user passes this flag and uses the new MLIR converter(via command-line), it will generate conversion logs under the specified folder. PiperOrigin-RevId: 278743450 Change-Id: Ic840a56642629514816582390a267b037b0bbb24 --- .../g3doc/r1/convert/cmdline_reference.md | 7 + tensorflow/lite/python/BUILD | 2 + tensorflow/lite/python/convert.py | 6 +- tensorflow/lite/python/lite.py | 6 +- tensorflow/lite/python/lite_test.py | 38 +++ tensorflow/lite/python/tflite_convert.py | 23 +- tensorflow/lite/python/tflite_convert_test.py | 24 ++ tensorflow/lite/toco/args.h | 1 + tensorflow/lite/toco/logging/BUILD | 97 +++++++ .../lite/toco/logging/conversion_log_util.cc | 241 ++++++++++++++++ .../lite/toco/logging/conversion_log_util.h | 58 ++++ .../toco/logging/conversion_log_util_test.cc | 228 +++++++++++++++ tensorflow/lite/toco/logging/gen_html.py | 259 ++++++++++++++++++ tensorflow/lite/toco/logging/gen_html_test.py | 116 ++++++++ tensorflow/lite/toco/logging/template.html | 163 +++++++++++ tensorflow/lite/toco/logging/testdata/BUILD | 6 + .../lite/toco/logging/testdata/generated.html | 163 +++++++++++ .../toco/logging/testdata/toco_log_after.pb | 14 + .../toco/logging/testdata/toco_log_before.pb | 10 + .../toco/logging/testdata/toco_tf_graph.dot | 1 + .../logging/testdata/toco_tflite_graph.dot | 1 + .../toco/logging/toco_conversion_log.proto | 50 ++++ tensorflow/lite/toco/model_cmdline_flags.cc | 3 + tensorflow/lite/toco/python/BUILD | 4 + .../lite/toco/python/toco_python_api.cc | 48 ++++ tensorflow/lite/toco/toco_flags.proto | 6 +- 26 files changed, 1571 insertions(+), 4 deletions(-) create mode 100644 tensorflow/lite/toco/logging/BUILD create mode 100644 tensorflow/lite/toco/logging/conversion_log_util.cc create mode 100644 tensorflow/lite/toco/logging/conversion_log_util.h create mode 100644 tensorflow/lite/toco/logging/conversion_log_util_test.cc create mode 100644 tensorflow/lite/toco/logging/gen_html.py create mode 100644 tensorflow/lite/toco/logging/gen_html_test.py create mode 100644 tensorflow/lite/toco/logging/template.html create mode 100644 tensorflow/lite/toco/logging/testdata/BUILD create mode 100644 tensorflow/lite/toco/logging/testdata/generated.html create mode 100644 tensorflow/lite/toco/logging/testdata/toco_log_after.pb create mode 100644 tensorflow/lite/toco/logging/testdata/toco_log_before.pb create mode 100644 tensorflow/lite/toco/logging/testdata/toco_tf_graph.dot create mode 100644 tensorflow/lite/toco/logging/testdata/toco_tflite_graph.dot create mode 100644 tensorflow/lite/toco/logging/toco_conversion_log.proto diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md index e1080f5d1f1..8cca69d5963 100644 --- a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md +++ b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md @@ -155,3 +155,10 @@ graph transformations: completed. * `--dump_graphviz_video`. Type: boolean. Outputs GraphViz after every graph transformation. Requires `--dump_graphviz_dir` to be specified. + +The following flag controls generating the conversion logs. The conversion log +includes a protocol buffer of analytics collected during conversion, and an HTML +file where user can preview the conversion summary. + +* `--conversion_summary_dir`. Type: string. Specifies the full path of the + directory to output conversion logs. diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 4a2fc7ba12a..a7f4c3e4804 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -71,6 +71,8 @@ py_library( visibility = ["//visibility:public"], deps = [ ":lite", + "//tensorflow/lite/toco/logging:gen_html", + "//tensorflow/lite/toco/logging:toco_conversion_log_proto_py", "@six_archive//:six", ], ) diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index acd5afa716e..acec1e62867 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -230,7 +230,8 @@ def build_toco_convert_protos(input_tensors, dump_graphviz_video=False, target_ops=None, allow_nonexistent_arrays=False, - debug_info=None): + debug_info=None, + conversion_summary_dir=None): """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which @@ -294,6 +295,7 @@ def build_toco_convert_protos(input_tensors, or are unused in the final graph. (default False) debug_info: `GraphDebugInfo` proto containing the stack traces for the original nodes referred by the converted graph. + conversion_summary_dir: A string, the path to the generated conversion logs. Returns: model_flags, toco_flags, debug_info: three protocol buffers describing the @@ -326,6 +328,8 @@ def build_toco_convert_protos(input_tensors, if dump_graphviz_dir: toco.dump_graphviz_dir = dump_graphviz_dir toco.dump_graphviz_include_video = dump_graphviz_video + if conversion_summary_dir: + toco.conversion_summary_dir = conversion_summary_dir if target_ops: if set(target_ops) == set([OpsSet.TFLITE_BUILTINS, OpsSet.SELECT_TF_OPS]): toco.enable_select_tf_ops = True diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index cf5d6d4f31e..7241024f6d9 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -537,6 +537,8 @@ class TFLiteConverter(TFLiteConverterBase): output file. (default None) dump_graphviz_video: Boolean indicating whether to dump the graph after every graph transformation. (default False) + conversion_summary_dir: A string indicating the path to the generated + conversion logs. target_ops: Deprecated. Please specify `target_spec.supported_ops` instead. Set of OpsSet options indicating which converter to use. (default set([OpsSet.TFLITE_BUILTINS])) @@ -621,6 +623,7 @@ class TFLiteConverter(TFLiteConverterBase): self._post_training_quantize = False self.dump_graphviz_dir = None self.dump_graphviz_video = False + self.conversion_summary_dir = None self._debug_info_func = experimental_debug_info_func # Attributes are used by models that cannot be loaded into TensorFlow. @@ -991,7 +994,8 @@ class TFLiteConverter(TFLiteConverterBase): "reorder_across_fake_quant": self.reorder_across_fake_quant, "change_concat_input_ranges": self.change_concat_input_ranges, "dump_graphviz_dir": self.dump_graphviz_dir, - "dump_graphviz_video": self.dump_graphviz_video + "dump_graphviz_video": self.dump_graphviz_video, + "conversion_summary_dir": self.conversion_summary_dir }) # Converts model. diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index c872d887c49..b9a99042a75 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -502,6 +502,44 @@ class FromSessionTest(TestModels, parameterized.TestCase): num_items_graphviz_video = len(os.listdir(graphviz_dir)) self.assertGreater(num_items_graphviz_video, num_items_graphviz) + def testDumpConversionSummary(self): + with ops.Graph().as_default(): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + log_dir = self.get_temp_dir() + converter.conversion_summary_dir = log_dir + # Conversion logs will only be generated when the mlir converter is enabled. + converter.experimental_new_converter = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + num_items_conversion_summary = len(os.listdir(log_dir)) + self.assertTrue(num_items_conversion_summary) + + def testDumpConversionSummaryWithOldConverter(self): + with ops.Graph().as_default(): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TFLiteConverter.from_session(sess, [in_tensor], + [out_tensor]) + log_dir = self.get_temp_dir() + converter.conversion_summary_dir = log_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + # Check nothing is generated under the conversion summary path. + num_items_conversion_summary = len(os.listdir(log_dir)) + self.assertEqual(num_items_conversion_summary, 0) + @parameterized.named_parameters( ('EnableMlirConverter', True), # enable mlir ('DisableMlirConverter', False)) # disable mlir diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py index 474249b76e3..06f69315582 100644 --- a/tensorflow/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -22,6 +22,7 @@ from __future__ import print_function import argparse import os import sys +import warnings import six from six.moves import zip @@ -29,6 +30,7 @@ from six.moves import zip from tensorflow.lite.python import lite from tensorflow.lite.python import lite_constants from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.lite.toco.logging import gen_html from tensorflow.python import keras from tensorflow.python import tf2 from tensorflow.python.platform import app @@ -198,6 +200,8 @@ def _convert_tf1_model(flags): converter.dump_graphviz_dir = flags.dump_graphviz_dir if flags.dump_graphviz_video: converter.dump_graphviz_vode = flags.dump_graphviz_video + if flags.conversion_summary_dir: + converter.conversion_summary_dir = flags.conversion_summary_dir if flags.experimental_new_converter: converter.experimental_new_converter = True @@ -479,6 +483,13 @@ def _get_tf1_flags(parser): action="store_true", help=("Boolean indicating whether to dump the graph after every graph " "transformation")) + parser.add_argument( + "--conversion_summary_dir", + type=str, + help=("Full filepath to store the conversion logs, which inclues graphviz" + " of the model before/after the conversion, an HTML report and the " + "conversion proto buffers. This will only be generated when passing" + " --experimental_new_converter")) def _get_tf2_flags(parser): @@ -565,7 +576,17 @@ def run_main(_): if use_v2_converter: _convert_tf2_model(tflite_flags) else: - _convert_tf1_model(tflite_flags) + try: + _convert_tf1_model(tflite_flags) + finally: + if tflite_flags.conversion_summary_dir: + if tflite_flags.experimental_new_converter: + gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir, + tflite_flags.post_training_quantize) + else: + warnings.warn( + "Conversion summary will only be generated when enabling" + " the new converter via --experimental_new_converter. ") def main(): diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py index 0338cb1cdaf..018a40f3214 100644 --- a/tensorflow/lite/python/tflite_convert_test.py +++ b/tensorflow/lite/python/tflite_convert_test.py @@ -147,6 +147,30 @@ class TfLiteConvertV1Test(TestModels): self._run(flags_str, should_succeed=True) os.remove(keras_file) + def testConversionSummary(self): + keras_file = self._getKerasModelFile() + log_dir = self.get_temp_dir() + + flags_str = ('--keras_model_file={} --experimental_new_converter ' + '--conversion_summary_dir={}'.format(keras_file, log_dir)) + self._run(flags_str, should_succeed=True) + os.remove(keras_file) + + num_items_conversion_summary = len(os.listdir(log_dir)) + self.assertTrue(num_items_conversion_summary) + + def testConversionSummaryWithOldConverter(self): + keras_file = self._getKerasModelFile() + log_dir = self.get_temp_dir() + + flags_str = ('--keras_model_file={} ' + '--conversion_summary_dir={}'.format(keras_file, log_dir)) + self._run(flags_str, should_succeed=True) + os.remove(keras_file) + + num_items_conversion_summary = len(os.listdir(log_dir)) + self.assertEqual(num_items_conversion_summary, 0) + class TfLiteConvertV2Test(TestModels): diff --git a/tensorflow/lite/toco/args.h b/tensorflow/lite/toco/args.h index 6b6bb78be55..c30ec316128 100644 --- a/tensorflow/lite/toco/args.h +++ b/tensorflow/lite/toco/args.h @@ -146,6 +146,7 @@ struct ParsedModelFlags { Arg graphviz_last_array; Arg dump_graphviz; Arg dump_graphviz_video = Arg(false); + Arg conversion_summary_dir; Arg allow_nonexistent_arrays = Arg(false); Arg allow_nonascii_arrays = Arg(false); Arg arrays_extra_info_file; diff --git a/tensorflow/lite/toco/logging/BUILD b/tensorflow/lite/toco/logging/BUILD new file mode 100644 index 00000000000..a27d5271322 --- /dev/null +++ b/tensorflow/lite/toco/logging/BUILD @@ -0,0 +1,97 @@ +load( + "//tensorflow/core/platform:default/build_config.bzl", + "tf_proto_library_cc", + "tf_proto_library_py", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +tf_proto_library_cc( + name = "toco_conversion_log_proto", + srcs = ["toco_conversion_log.proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_py( + name = "toco_conversion_log_proto", + srcs = ["toco_conversion_log.proto"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "conversion_log_util", + srcs = ["conversion_log_util.cc"], + hdrs = ["conversion_log_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":toco_conversion_log_proto_cc", + "//tensorflow/core:protos_all", + "//tensorflow/lite:version", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:tooling_util", + "//tensorflow/lite/toco/tflite:export", + "//tensorflow/lite/toco/tflite:operator", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +tf_cc_test( + name = "conversion_log_util_test", + srcs = ["conversion_log_util_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":conversion_log_util", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all", + "//tensorflow/lite/toco:model", + "//tensorflow/lite/toco:model_flags_proto", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "html_template", + srcs = [ + "template.html", + ], +) + +py_library( + name = "gen_html", + srcs = ["gen_html.py"], + data = [ + "html_template", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) + +py_test( + name = "gen_html_test", + srcs = ["gen_html_test.py"], + data = [ + "//tensorflow/lite/toco/logging:template.html", + "//tensorflow/lite/toco/logging/testdata:generated.html", + "//tensorflow/lite/toco/logging/testdata:toco_log_after.pb", + "//tensorflow/lite/toco/logging/testdata:toco_log_before.pb", + "//tensorflow/lite/toco/logging/testdata:toco_tf_graph.dot", + "//tensorflow/lite/toco/logging/testdata:toco_tflite_graph.dot", + ], + python_version = "PY2", + srcs_version = "PY2AND3", + deps = [ + ":gen_html", + ":toco_conversion_log_proto_py", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/lite/toco/logging/conversion_log_util.cc b/tensorflow/lite/toco/logging/conversion_log_util.cc new file mode 100644 index 00000000000..86b64e6d3dd --- /dev/null +++ b/tensorflow/lite/toco/logging/conversion_log_util.cc @@ -0,0 +1,241 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/toco/logging/conversion_log_util.h" + +#include + +#include + +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tflite/export.h" +#include "tensorflow/lite/toco/tflite/operator.h" +#include "tensorflow/lite/toco/tooling_util.h" +#include "tensorflow/lite/version.h" + +namespace toco { + +namespace { + +string TryGetOperatorName(const Operator& op) { + string op_name; + if (!op.tensorflow_node_def.empty()) { + // Parse op name from serialized NodeDef. + tensorflow::NodeDef node_def; + if (!node_def.ParseFromString(op.tensorflow_node_def)) { + LOG(ERROR) << "Failed to parse Tensorflow NodeDef"; + } else { + op_name = node_def.op(); + if (!op_name.empty()) return op_name; + } + } + if (op.type == OperatorType::kUnsupported) { + // If we failed to get op name from serialized NodeDef (either because + // the tensorflow_node_def is an empty string, or we failed to parse + // from it), fall back to use 'tensorflow_op' field if this op is a + // TensorflowUnsupportedOperator. + const TensorFlowUnsupportedOperator& unsupported_op = + static_cast(op); + if (!unsupported_op.tensorflow_op.empty()) { + op_name = unsupported_op.tensorflow_op; + return op_name; + } + } + // If this is a built-in op. + op_name = OperatorTypeName(op.type); + return op_name; +} + +string GetOSVersion() { + utsname info; + if (uname(&info)) { + // Failed + LOG(ERROR) << "Cannot get OS info."; + return ""; + } + string os_info = + string(info.sysname) + ";OSVer=" + string(info.release) + ";"; + return os_info; +} + +string ShapeToStringNoSpace(const Shape& shape) { + if (shape.dimensions_count() == 0) { + return "[]"; + } + + return absl::StrCat("[", absl::StrJoin(shape.dims(), ","), "]"); +} + +string GetOperatorSignature( + const Model& model, const Operator& op, + const std::map>& + op_types_map) { + // The signature of an op has the following schema: + // INPUT:SHAPE::TYPE::OUTPUT:SHAPE::TYPE::NAME:VERSION: + string op_signature; + constexpr char delimiter[] = "::"; + + // Get input shapes and types. + op_signature.append("INPUT:"); + for (const auto& input : op.inputs) { + const auto& array = model.GetArray(input); + if (array.has_shape()) { + op_signature.append(ShapeToStringNoSpace(array.shape())); + } else { + op_signature.append("None"); + } + op_signature.append(delimiter); + op_signature.append(ArrayDataTypeName(array.data_type) + delimiter); + } + // Get output shapes and types. + op_signature.append("OUTPUT:"); + for (const auto& output : op.outputs) { + const auto& array = model.GetArray(output); + if (array.has_shape()) { + op_signature.append(ShapeToStringNoSpace(array.shape())); + } else { + op_signature.append("None"); + } + op_signature.append(delimiter); + op_signature.append(ArrayDataTypeName(array.data_type) + delimiter); + } + // Append Op name. + op_signature.append("NAME:"); + op_signature.append(TryGetOperatorName(op) + delimiter); + // Append Op version. + op_signature.append("VERSION:"); + OperatorSignature toco_op_signature; + toco_op_signature.op = &op; + toco_op_signature.model = &model; + if (op_types_map.find(op.type) != op_types_map.end()) { + const int version = op_types_map.at(op.type)->GetVersion(toco_op_signature); + op_signature.append(std::to_string(version)); + } else { + op_signature.append("None"); + } + return op_signature; +} + +} // namespace + +std::vector GetOperatorNames(const Model& model) { + std::vector op_names; + for (const auto& op : model.operators) { + op_names.push_back(TryGetOperatorName(*op)); + } + return op_names; +} + +void CountOperatorsByType(const Model& model, + std::map* built_in_ops, + std::map* custom_ops, + std::map* select_ops) { + for (const auto& op : model.operators) { + OperatorSignature op_signature = {op.get(), &model}; + const auto ops_by_type = + tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/); + tflite::details::OperatorKey op_key(op_signature, ops_by_type, + true /*enable_select_tf_ops*/); + + const string op_name = TryGetOperatorName(*op); + if (op_key.is_custom_op()) { + (*custom_ops)[op_name]++; + } else if (op_key.is_flex_op()) { + (*select_ops)[op_name]++; + } else { + (*built_in_ops)[op_name]++; + } + } +} + +void GetInputAndOutputTypes( + const Model& model, TFLITE_PROTO_NS::RepeatedPtrField* input_types, + TFLITE_PROTO_NS::RepeatedPtrField* output_types) { + for (const auto& input_array : model.flags.input_arrays()) { + const Array& array = model.GetArray(input_array.name()); + input_types->Add(ArrayDataTypeName(array.data_type)); + } + for (const auto& output_array : model.flags.output_arrays()) { + const Array& array = model.GetArray(output_array); + output_types->Add(ArrayDataTypeName(array.data_type)); + } +} + +string GetTfLiteVersion() { return TFLITE_VERSION_STRING; } + +string GetCachedOSVersion() { + static string* version = new string(GetOSVersion()); + return *version; +} + +void GetOpSignatures(const Model& model, + TFLITE_PROTO_NS::RepeatedPtrField* op_signatures) { + const auto& op_types_map = + tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/); + for (const auto& op : model.operators) { + op_signatures->Add(GetOperatorSignature(model, *op, op_types_map)); + } +} + +string GetModelHash(const Model& model) { + // TODO(b/123519920): Implement the hash function for Model. + // Need to consider different implementations for public/private models. + return ""; +} + +void PopulateConversionLog(const Model& model, TocoConversionLog* log) { + // Get the list of ops after conversion. + const std::vector op_names = GetOperatorNames(model); + for (const auto& op_name : op_names) { + log->add_op_list(op_name); + } + + // Get op signatures. + TFLITE_PROTO_NS::RepeatedPtrField op_signatures; + GetOpSignatures(model, &op_signatures); + log->mutable_op_signatures()->CopyFrom(op_signatures); + + // Get op counts by category: custom, built-in or select. + std::map custom_ops, select_ops, built_in_ops; + CountOperatorsByType(model, &built_in_ops, &custom_ops, &select_ops); + log->mutable_custom_ops()->insert(custom_ops.cbegin(), custom_ops.cend()); + log->mutable_built_in_ops()->insert(built_in_ops.cbegin(), + built_in_ops.cend()); + log->mutable_select_ops()->insert(select_ops.cbegin(), select_ops.cend()); + + // Get the model's input and output types. + TFLITE_PROTO_NS::RepeatedPtrField input_types, output_types; + GetInputAndOutputTypes(model, &input_types, &output_types); + log->mutable_input_tensor_types()->CopyFrom(input_types); + log->mutable_output_tensor_types()->CopyFrom(output_types); + + log->set_log_generation_ts(absl::ToUnixMicros(absl::Now())); + + log->set_model_size(model.operators.size()); + log->set_tf_lite_version(GetTfLiteVersion()); + log->set_os_version(GetCachedOSVersion()); + log->set_model_hash(GetModelHash(model)); + // TODO(b/123519920): Populate TOCO error logs. + // Currently we will focus on external installation of TOCO via pip, where + // the C++ TOCO binary is invoked via subprocess command, this will make our + // life easier collecting the error logs emitted by TOCO. However, note that + // if a user directly invokes the C++ TOCO binary, this log might not be + // available. +} + +} // namespace toco diff --git a/tensorflow/lite/toco/logging/conversion_log_util.h b/tensorflow/lite/toco/logging/conversion_log_util.h new file mode 100644 index 00000000000..0cd1a537b08 --- /dev/null +++ b/tensorflow/lite/toco/logging/conversion_log_util.h @@ -0,0 +1,58 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TOCO_LOGGING_CONVERSION_LOG_UTIL_H_ +#define TENSORFLOW_LITE_TOCO_LOGGING_CONVERSION_LOG_UTIL_H_ + +#include +#include + +#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h" +#include "tensorflow/lite/toco/model.h" + +namespace toco { + +// Populates the TocoConversionLog proto after analyzing the model. +void PopulateConversionLog(const Model& model, TocoConversionLog* log); + +// Returns the names of the operators in the model. +std::vector GetOperatorNames(const Model& model); + +// Counts the number of different types of operators in the model: +// Built-in ops, custom ops and select ops. +// Each map is mapping from the name of the operator (such as 'Conv') to its +// total number of occurences in the model. +void CountOperatorsByType(const Model& model, + std::map* built_in_ops, + std::map* custom_ops, + std::map* select_ops); + +// Gets the input and output types of the model. The input and output is +// specified by model.flags.input_arrays and model.flags.output_arrays. +void GetInputAndOutputTypes( + const Model& model, TFLITE_PROTO_NS::RepeatedPtrField* input_types, + TFLITE_PROTO_NS::RepeatedPtrField* output_types); + +// Calculates signatures for all the ops in the model. An op signature is +// defined by its input/output shapes and types, op name and its version. +void GetOpSignatures(const Model& model, + TFLITE_PROTO_NS::RepeatedPtrField* op_signatures); + +// TODO(b/123519920): Implement this. +// Calculates a unique hash for the model. +string GetModelHash(const Model& model); + +} // namespace toco + +#endif // TENSORFLOW_LITE_TOCO_LOGGING_CONVERSION_LOG_UTIL_H_ diff --git a/tensorflow/lite/toco/logging/conversion_log_util_test.cc b/tensorflow/lite/toco/logging/conversion_log_util_test.cc new file mode 100644 index 00000000000..ac53471cec3 --- /dev/null +++ b/tensorflow/lite/toco/logging/conversion_log_util_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/toco/logging/conversion_log_util.h" + +#include +#include +#include + +#include +#include +#include "absl/memory/memory.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" + +namespace toco { +namespace { + +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +TEST(ConversionLogUtilTest, TestGetOperatorNames) { + Model model; + // Built-in ops. + model.operators.push_back(absl::make_unique()); + model.operators.push_back(absl::make_unique()); + model.operators.push_back(absl::make_unique()); + // Flex ops. + auto avg_pool_3d = absl::make_unique(); + avg_pool_3d->tensorflow_op = "AvgPool3D"; + tensorflow::NodeDef node_def; + node_def.set_op("AvgPool3D"); + node_def.SerializeToString(&avg_pool_3d->tensorflow_node_def); + model.operators.push_back(std::move(avg_pool_3d)); + // Custom ops. + auto my_custom_op = absl::make_unique(); + my_custom_op->tensorflow_op = "MyAwesomeCustomOp"; + model.operators.push_back(std::move(my_custom_op)); + + const auto& output = GetOperatorNames(model); + EXPECT_THAT(output, ElementsAre("Conv", "Mean", "Neg", "AvgPool3D", + "MyAwesomeCustomOp")); +} + +TEST(ConversionLogUtilTest, TestCountOperatorsByType) { + Model model; + // 1st Conv operator. + std::unique_ptr conv1(new ConvOperator()); + const string conv1_input_name = "conv_input1"; + const string conv1_filter_name = "conv_filter1"; + const string conv1_output_name = "conv_output1"; + conv1->inputs.push_back(conv1_input_name); + conv1->inputs.push_back(conv1_filter_name); + conv1->outputs.push_back(conv1_output_name); + auto& array_map = model.GetMutableArrayMap(); + array_map[conv1_input_name] = std::unique_ptr(new Array); + array_map[conv1_filter_name] = std::unique_ptr(new Array); + array_map[conv1_output_name] = std::unique_ptr(new Array); + + // 2nd Conv operator. + std::unique_ptr conv2(new ConvOperator()); + const string conv2_input_name = "conv_input2"; + const string conv2_filter_name = "conv_filter2"; + const string conv2_output_name = "conv_output2"; + conv2->inputs.push_back(conv2_input_name); + conv2->inputs.push_back(conv2_filter_name); + conv2->outputs.push_back(conv2_output_name); + array_map[conv2_input_name] = std::unique_ptr(new Array); + array_map[conv2_filter_name] = std::unique_ptr(new Array); + array_map[conv2_output_name] = std::unique_ptr(new Array); + + // Mean operator. + std::unique_ptr mean(new MeanOperator()); + const string mean_input_name = "mean_input"; + mean->inputs.push_back(mean_input_name); + array_map[mean_input_name] = std::unique_ptr(new Array); + + // 1st flex operator 'AvgPool3D'. + auto avg_pool_3d = absl::make_unique(); + avg_pool_3d->tensorflow_op = "AvgPool3D"; + tensorflow::NodeDef node_def; + node_def.set_op("AvgPool3D"); + node_def.SerializeToString(&avg_pool_3d->tensorflow_node_def); + + // 2nd flex operator 'EluGrad'. + auto elu_grad = absl::make_unique(); + elu_grad->tensorflow_op = "EluGrad"; + node_def.set_op("EluGrad"); + node_def.SerializeToString(&elu_grad->tensorflow_node_def); + + // 1st custom operator 'MyAwesomeCustomOp'. + auto my_custom_op = absl::make_unique(); + my_custom_op->tensorflow_op = "MyAwesomeCustomOp"; + + model.operators.push_back(std::move(conv1)); + model.operators.push_back(std::move(conv2)); + model.operators.push_back(std::move(mean)); + model.operators.push_back(std::move(avg_pool_3d)); + model.operators.push_back(std::move(elu_grad)); + model.operators.push_back(std::move(my_custom_op)); + + std::map built_in_ops, select_ops, custom_ops; + CountOperatorsByType(model, &built_in_ops, &custom_ops, &select_ops); + + EXPECT_THAT(built_in_ops, + UnorderedElementsAre(std::pair("Conv", 2), + std::pair("Mean", 1))); + EXPECT_THAT(select_ops, + UnorderedElementsAre(std::pair("AvgPool3D", 1), + std::pair("EluGrad", 1))); + EXPECT_THAT(custom_ops, UnorderedElementsAre( + std::pair("MyAwesomeCustomOp", 1))); +} + +TEST(ConversionLogUtilTest, TestGetInputAndOutputTypes) { + Model model; + auto& array_map = model.GetMutableArrayMap(); + const string input1 = "conv_input"; + const string input2 = "conv_filter"; + const string input3 = "feature"; + const string output = "softmax"; + array_map[input1] = std::unique_ptr(new Array); + array_map[input1]->data_type = ArrayDataType::kFloat; + array_map[input2] = std::unique_ptr(new Array); + array_map[input2]->data_type = ArrayDataType::kFloat; + array_map[input3] = std::unique_ptr(new Array); + array_map[input3]->data_type = ArrayDataType::kInt16; + array_map[output] = std::unique_ptr(new Array); + array_map[output]->data_type = ArrayDataType::kFloat; + + InputArray input_arrays[3]; + input_arrays[0].set_name(input1); + input_arrays[1].set_name(input2); + input_arrays[2].set_name(input3); + *model.flags.add_input_arrays() = input_arrays[0]; + *model.flags.add_input_arrays() = input_arrays[1]; + *model.flags.add_input_arrays() = input_arrays[2]; + model.flags.add_output_arrays(output); + + TFLITE_PROTO_NS::RepeatedPtrField input_types, output_types; + GetInputAndOutputTypes(model, &input_types, &output_types); + + EXPECT_THAT(input_types, ElementsAre("float", "float", "int16")); + EXPECT_THAT(output_types, ElementsAre("float")); +} + +TEST(ConversionLogUtilTest, TestGetOpSignatures) { + Model model; + auto& array_map = model.GetMutableArrayMap(); + + std::unique_ptr conv(new ConvOperator()); + const string conv_input_name = "conv_input"; + const string conv_filter_name = "conv_filter"; + const string conv_output_name = "conv_output"; + conv->inputs.push_back(conv_input_name); + conv->inputs.push_back(conv_filter_name); + conv->outputs.push_back(conv_output_name); + array_map[conv_input_name] = std::unique_ptr(new Array); + array_map[conv_input_name]->data_type = ArrayDataType::kFloat; + array_map[conv_input_name]->copy_shape({4, 4, 3}); + array_map[conv_filter_name] = std::unique_ptr(new Array); + array_map[conv_filter_name]->data_type = ArrayDataType::kFloat; + array_map[conv_filter_name]->copy_shape({2, 2}); + array_map[conv_output_name] = std::unique_ptr(new Array); + array_map[conv_output_name]->data_type = ArrayDataType::kFloat; + array_map[conv_output_name]->copy_shape({4, 4, 2}); + + const string mean_input_name = "mean_input"; + const string mean_output_name = "mean_output"; + std::unique_ptr mean(new MeanOperator()); + mean->inputs.push_back(mean_input_name); + mean->outputs.push_back(mean_output_name); + array_map[mean_input_name] = std::unique_ptr(new Array); + array_map[mean_output_name] = std::unique_ptr(new Array); + + const string avg_pool_3d_output_name = "avg_pool_output"; + auto avg_pool_3d = absl::make_unique(); + avg_pool_3d->tensorflow_op = "AvgPool3D"; + tensorflow::NodeDef node_def; + node_def.set_op("AvgPool3D"); + node_def.SerializeToString(&avg_pool_3d->tensorflow_node_def); + avg_pool_3d->inputs.push_back(conv_output_name); + avg_pool_3d->outputs.push_back(avg_pool_3d_output_name); + array_map[avg_pool_3d_output_name] = std::unique_ptr(new Array); + array_map[avg_pool_3d_output_name]->data_type = ArrayDataType::kInt32; + array_map[avg_pool_3d_output_name]->copy_shape({2, 2}); + + const string custom_op_output_name = "custom_op_output"; + auto my_custom_op = absl::make_unique(); + my_custom_op->tensorflow_op = "MyAwesomeCustomOp"; + my_custom_op->inputs.push_back(avg_pool_3d_output_name); + my_custom_op->outputs.push_back(custom_op_output_name); + array_map[custom_op_output_name] = std::unique_ptr(new Array); + array_map[custom_op_output_name]->data_type = ArrayDataType::kFloat; + array_map[custom_op_output_name]->copy_shape({3}); + + model.operators.push_back(std::move(conv)); + model.operators.push_back(std::move(mean)); + model.operators.push_back(std::move(avg_pool_3d)); + model.operators.push_back(std::move(my_custom_op)); + + TFLITE_PROTO_NS::RepeatedPtrField op_signatures; + GetOpSignatures(model, &op_signatures); + EXPECT_THAT(op_signatures, + UnorderedElementsAre( + "INPUT:[4,4,3]::float::[2,2]::float::OUTPUT:[4,4,2]::float::" + "NAME:Conv::VERSION:1", + "INPUT:None::None::OUTPUT:None::None::NAME:Mean::VERSION:1", + "INPUT:[4,4,2]::float::OUTPUT:[2,2]::int32::NAME:AvgPool3D::" + "VERSION:1", + "INPUT:[2,2]::int32::OUTPUT:[3]::float::NAME:" + "MyAwesomeCustomOp::VERSION:1")); +} + +} // namespace +} // namespace toco diff --git a/tensorflow/lite/toco/logging/gen_html.py b/tensorflow/lite/toco/logging/gen_html.py new file mode 100644 index 00000000000..33ac3f1a006 --- /dev/null +++ b/tensorflow/lite/toco/logging/gen_html.py @@ -0,0 +1,259 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A utility class to generate the report HTML based on a common template.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import os + +from tensorflow.lite.toco.logging import toco_conversion_log_pb2 as _toco_conversion_log_pb2 +from tensorflow.python.lib.io import file_io as _file_io +from tensorflow.python.platform import resource_loader as _resource_loader + +html_escape_table = { + "&": "&", + '"': """, + "'": "'", + ">": ">", + "<": "<", +} + + +def html_escape(text): + return "".join(html_escape_table.get(c, c) for c in text) + + +def get_input_type_from_signature(op_signature): + """Parses op_signature and returns a string denoting the input tensor type. + + Args: + op_signature: a string specifying the signature of a particular operator. + The signature of an operator contains the input tensor's shape and type, + output tensor's shape and type, operator's name and its version. It has + the following schema: + INPUT:input_1_shape::input_1_type::input_2_shape::input_2_type::.. + ::OUTPUT:output_1_shape::output_1_type::output_2_shape::output_2_type:: + ..::NAME:operator_name ::VERSION:operator_version + An example of an operator signature is: + INPUT:[1,73,73,160]::float::[64,1,1,160]::float::[64]::float:: + OUTPUT:[1,73,73,64]::float::NAME:Conv::VERSION:1 + + Returns: + A string denoting the input tensors' type. In the form of shape/type + separated + by comma. For example: + shape:[1,73,73,160],type:float,shape:[64,1,1,160],type:float,shape:[64], + type:float + """ + start = op_signature.find(":") + end = op_signature.find("::OUTPUT") + inputs = op_signature[start + 1:end] + lst = inputs.split("::") + out_str = "" + for i in range(len(lst)): + if i % 2 == 0: + out_str += "shape:" + else: + out_str += "type:" + out_str += lst[i] + out_str += "," + return out_str[:-1] + + +def get_operator_type(op_name, conversion_log): + if op_name in conversion_log.built_in_ops: + return "BUILT-IN" + elif op_name in conversion_log.custom_ops: + return "CUSTOM OP" + else: + return "SELECT OP" + + +class HTMLGenerator(object): + """Utility class to generate an HTML report.""" + + def __init__(self, html_template_path, export_report_path): + """Reads the HTML template content. + + Args: + html_template_path: A string, path to the template HTML file. + export_report_path: A string, path to the generated HTML report. This path + should point to a '.html' file with date and time in its name. + e.g. 2019-01-01-10:05.toco_report.html. + + Raises: + IOError: File doesn't exist. + """ + # Load the template HTML. + if not _file_io.file_exists(html_template_path): + raise IOError("File '{0}' does not exist.".format(html_template_path)) + with _file_io.FileIO(html_template_path, "r") as f: + self.html_template = f.read() + + _file_io.recursive_create_dir(os.path.dirname(export_report_path)) + self.export_report_path = export_report_path + + def generate(self, + toco_conversion_log_before, + toco_conversion_log_after, + post_training_quant_enabled, + dot_before, + dot_after, + toco_err_log=""): + """Generates the HTML report and writes it to local directory. + + This function uses the fields in `toco_conversion_log_before` and + `toco_conversion_log_after` to populate the HTML content. Certain markers + (placeholders) in the HTML template are then substituted with the fields + from the protos. Once finished it will write the HTML file to the specified + local file path. + + Args: + toco_conversion_log_before: A `TocoConversionLog` protobuf generated + before the model is converted by TOCO. + toco_conversion_log_after: A `TocoConversionLog` protobuf generated after + the model is converted by TOCO. + post_training_quant_enabled: A boolean, whether post-training quantization + is enabled. + dot_before: A string, the dot representation of the model + before the conversion. + dot_after: A string, the dot representation of the model after + the conversion. + toco_err_log: A string, the logs emitted by TOCO during conversion. Caller + need to ensure that this string is properly anoynimized (any kind of + user data should be eliminated). + + Raises: + RuntimeError: When error occurs while generating the template. + """ + html_dict = {} + html_dict[""] = ( + r'Fail' + ) if toco_err_log else r'Success' + html_dict[""] = str( + toco_conversion_log_before.model_size) + html_dict[""] = str( + toco_conversion_log_after.model_size) + html_dict[""] = str( + sum(toco_conversion_log_after.built_in_ops.values())) + html_dict[""] = str( + sum(toco_conversion_log_after.select_ops.values())) + html_dict[""] = str( + sum(toco_conversion_log_after.custom_ops.values())) + html_dict[""] = ( + "is" if post_training_quant_enabled else "isn't") + + pre_op_profile = "" + post_op_profile = "" + + # Generate pre-conversion op profiles as a list of HTML table rows. + for i in range(len(toco_conversion_log_before.op_list)): + # Append operator name column. + pre_op_profile += "" + toco_conversion_log_before.op_list[ + i] + "" + # Append input type column. + if i < len(toco_conversion_log_before.op_signatures): + pre_op_profile += "" + get_input_type_from_signature( + toco_conversion_log_before.op_signatures[i]) + "" + else: + pre_op_profile += "" + + # Generate post-conversion op profiles as a list of HTML table rows. + for op in toco_conversion_log_after.op_list: + supported_type = get_operator_type(op, toco_conversion_log_after) + post_op_profile += ("" + op + "" + supported_type + + "") + + html_dict[""] = pre_op_profile + html_dict[""] = post_op_profile + html_dict[""] = dot_before + html_dict[""] = dot_after + html_dict[""] = html_escape(toco_err_log) + + # Replace each marker (as keys of html_dict) with the actual text (as values + # of html_dict) in the HTML template string. + template = self.html_template + for marker in html_dict: + template = template.replace(marker, html_dict[marker], 1) + # Check that the marker text is replaced. + if template.find(marker) != -1: + raise RuntimeError("Could not populate marker text %r" % marker) + + with _file_io.FileIO(self.export_report_path, "w") as f: + f.write(template) + + +def gen_conversion_log_html(conversion_log_dir, quantization_enabled): + """Generates an HTML report about the conversion process. + + Args: + conversion_log_dir: A string specifying the file directory of the conversion + logs. It's required that before calling this function, the + `conversion_log_dir` + already contains the following files: `toco_log_before.pb`, + `toco_log_after.pb`, `toco_tf_graph.dot`, + `toco_tflite_graph.dot`. + quantization_enabled: A boolean, passed from the tflite converter to + indicate whether post-training quantization is enabled during conversion. + + Raises: + IOError: When any of the required files doesn't exist. + """ + template_filename = _resource_loader.get_path_to_datafile("template.html") + if not os.path.exists(template_filename): + raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format( + template_filename)) + + toco_log_before_path = os.path.join(conversion_log_dir, "toco_log_before.pb") + toco_log_after_path = os.path.join(conversion_log_dir, "toco_log_after.pb") + dot_before_path = os.path.join(conversion_log_dir, "toco_tf_graph.dot") + dot_after_path = os.path.join(conversion_log_dir, "toco_tflite_graph.dot") + if not os.path.exists(toco_log_before_path): + raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format( + toco_log_before_path)) + if not os.path.exists(toco_log_after_path): + raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format( + toco_log_after_path)) + if not os.path.exists(dot_before_path): + raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format( + dot_before_path)) + if not os.path.exists(dot_after_path): + raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format( + dot_after_path)) + + html_generator = HTMLGenerator( + template_filename, + os.path.join(conversion_log_dir, "toco_conversion_summary.html")) + + # Parse the generated `TocoConversionLog`. + toco_conversion_log_before = _toco_conversion_log_pb2.TocoConversionLog() + toco_conversion_log_after = _toco_conversion_log_pb2.TocoConversionLog() + with open(toco_log_before_path, "rb") as f: + toco_conversion_log_before.ParseFromString(f.read()) + with open(toco_log_after_path, "rb") as f: + toco_conversion_log_after.ParseFromString(f.read()) + + # Read the dot file before/after the conversion. + with io.open(dot_before_path, "r", encoding="utf-8") as f: + dot_before = f.read().rstrip() + with io.open(dot_after_path, "r", encoding="utf-8") as f: + dot_after = f.read().rstrip() + + html_generator.generate(toco_conversion_log_before, toco_conversion_log_after, + quantization_enabled, dot_before, dot_after, + toco_conversion_log_after.toco_err_logs) diff --git a/tensorflow/lite/toco/logging/gen_html_test.py b/tensorflow/lite/toco/logging/gen_html_test.py new file mode 100644 index 00000000000..2cabe59bf54 --- /dev/null +++ b/tensorflow/lite/toco/logging/gen_html_test.py @@ -0,0 +1,116 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for gen_html.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil + +from tensorflow.lite.toco.logging import gen_html +from tensorflow.lite.toco.logging import toco_conversion_log_pb2 as _toco_conversion_log_pb2 +from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import file_io as _file_io +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class GenHtmlTest(test_util.TensorFlowTestCase): + + def test_generate_html(self): + toco_conversion_log_before = _toco_conversion_log_pb2.TocoConversionLog() + toco_conversion_log_after = _toco_conversion_log_pb2.TocoConversionLog() + + toco_conversion_log_before.op_list.extend([ + "Conv1", "Conv2", "Identity", "Reshape", "Dense", "Dense", "CustomOp", + "AvgPool3D", "Softmax" + ]) + toco_conversion_log_before.model_size = 9 + + toco_conversion_log_after.op_list.extend([ + "Conv1", "Conv2", "Dense", "Dense", "CustomOp", "AvgPool3D", "Softmax" + ]) + toco_conversion_log_after.built_in_ops["Conv1"] = 1 + toco_conversion_log_after.built_in_ops["Conv2"] = 1 + toco_conversion_log_after.built_in_ops["Dense"] = 2 + toco_conversion_log_after.built_in_ops["Softmax"] = 1 + toco_conversion_log_after.custom_ops["CustomOp"] = 1 + toco_conversion_log_after.select_ops["AvgPool3D"] = 1 + toco_conversion_log_after.model_size = 7 + + export_path = os.path.join(self.get_temp_dir(), "generated.html") + html_generator = gen_html.HTMLGenerator( + html_template_path=resource_loader.get_path_to_datafile( + "template.html"), + export_report_path=export_path) + + html_generator.generate(toco_conversion_log_before, + toco_conversion_log_after, True, + "digraph {a -> b}", "digraph {a -> b}") + + with _file_io.FileIO(export_path, "r") as f_export, _file_io.FileIO( + resource_loader.get_path_to_datafile("testdata/generated.html"), + "r") as f_expect: + expected = f_expect.read() + exported = f_export.read() + self.assertEqual(exported, expected) + + def test_gen_conversion_log_html(self): + # Copies all required data files into a temporary folder for testing. + export_path = self.get_temp_dir() + toco_log_before_path = resource_loader.get_path_to_datafile( + "testdata/toco_log_before.pb") + toco_log_after_path = resource_loader.get_path_to_datafile( + "testdata/toco_log_after.pb") + dot_before = resource_loader.get_path_to_datafile( + "testdata/toco_tf_graph.dot") + dot_after = resource_loader.get_path_to_datafile( + "testdata/toco_tflite_graph.dot") + shutil.copy(toco_log_before_path, export_path) + shutil.copy(toco_log_after_path, export_path) + shutil.copy(dot_before, export_path) + shutil.copy(dot_after, export_path) + + # Generate HTML content based on files in the test folder. + gen_html.gen_conversion_log_html(export_path, True) + + result_html = os.path.join(export_path, "toco_conversion_summary.html") + + with _file_io.FileIO(result_html, "r") as f_export, _file_io.FileIO( + resource_loader.get_path_to_datafile("testdata/generated.html"), + "r") as f_expect: + expected = f_expect.read() + exported = f_export.read() + self.assertEqual(exported, expected) + + def test_get_input_type_from_signature(self): + op_signatures = [ + ("INPUT:[1,73,73,160]::float::[64,1,1,160]::float::[64]::float::" + "OUTPUT:[1,73,73,64]::float::NAME:Conv::VERSION:1") + ] + expect_input_types = [ + ("shape:[1,73,73,160],type:float,shape:[64,1,1,160],type:float," + "shape:[64],type:float") + ] + for i in range(len(op_signatures)): + self.assertEqual( + gen_html.get_input_type_from_signature(op_signatures[i]), + expect_input_types[i]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/lite/toco/logging/template.html b/tensorflow/lite/toco/logging/template.html new file mode 100644 index 00000000000..d1a7f216b81 --- /dev/null +++ b/tensorflow/lite/toco/logging/template.html @@ -0,0 +1,163 @@ + + + + + + + + + + +Tensorflow Conversion Tooling & Logging + + + + + + + + + +
+ + + + + +
+ + + + +

Conversion Summary + +

+ +
    +
  • Total ops used before conversion:
  • +
  • Total ops used post conversion: + Built-in Ops + Select Ops + Custom Ops +
  • +
  • Post-training quantization applied.
  • +
+ +

Operator Profile

+ +
+ +
+ + + + + + + + + + + + + + +
Pre-conversion Op Profile
Operations Used + Input Types +
+
+
+ +
+
+ + + + + + + + + + + + + + +
Post-conversion Op Profile
Operations Used + Supported +
+
+
+ +

Graph Visualization

+
+
+
+ + + + + + + + + + + + + + + + + + + + + + +
Compare the model before/after conversion
Before Conversion + After Conversion +
+
+
+ +

Conversion Log

+
+
+    
+  
+
+
+
+ + + + + diff --git a/tensorflow/lite/toco/logging/testdata/BUILD b/tensorflow/lite/toco/logging/testdata/BUILD new file mode 100644 index 00000000000..fe638cb5d7a --- /dev/null +++ b/tensorflow/lite/toco/logging/testdata/BUILD @@ -0,0 +1,6 @@ +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(glob(["*.html"]) + glob(["*.pb"]) + glob(["*.dot"])) diff --git a/tensorflow/lite/toco/logging/testdata/generated.html b/tensorflow/lite/toco/logging/testdata/generated.html new file mode 100644 index 00000000000..1673ad879f9 --- /dev/null +++ b/tensorflow/lite/toco/logging/testdata/generated.html @@ -0,0 +1,163 @@ + + + + + + + + + + +Tensorflow Conversion Tooling & Logging + + + + + + + + + +
+ + + + + +
+ + + + +

Conversion Summary + Success +

+ +
    +
  • Total ops used before conversion: 9
  • +
  • Total ops used post conversion: 7 + 5 Built-in Ops + 1 Select Ops + 1 Custom Ops +
  • +
  • Post-training quantization is applied.
  • +
+ +

Operator Profile

+ +
+ +
+ + + + + + + + + + + + + + +
Pre-conversion Op Profile
Operations Used + Input Types +
Conv1
Conv2
Identity
Reshape
Dense
Dense
CustomOp
AvgPool3D
Softmax
+
+
+ +
+
+ + + + + + + + + + + + + + +
Post-conversion Op Profile
Operations Used + Supported +
Conv1BUILT-IN
Conv2BUILT-IN
DenseBUILT-IN
DenseBUILT-IN
CustomOpCUSTOM OP
AvgPool3DSELECT OP
SoftmaxBUILT-IN
+
+
+ +

Graph Visualization

+
+
+
+ + + + + + + + + + + + + + + + + + + + + + +
Compare the model before/after conversion
Before Conversion + After Conversion +
+
+
+ +

Conversion Log

+
+
+    
+  
+
+
+
+ + + + + diff --git a/tensorflow/lite/toco/logging/testdata/toco_log_after.pb b/tensorflow/lite/toco/logging/testdata/toco_log_after.pb new file mode 100644 index 00000000000..10c08fae24b --- /dev/null +++ b/tensorflow/lite/toco/logging/testdata/toco_log_after.pb @@ -0,0 +1,14 @@ + +Conv1 +Conv2 +Dense +Dense +CustomOp + AvgPool3D +Softmax +Softmax +Conv2 +Conv1 +Dense +CustomOp" + AvgPool3DH \ No newline at end of file diff --git a/tensorflow/lite/toco/logging/testdata/toco_log_before.pb b/tensorflow/lite/toco/logging/testdata/toco_log_before.pb new file mode 100644 index 00000000000..f1f0f32a198 --- /dev/null +++ b/tensorflow/lite/toco/logging/testdata/toco_log_before.pb @@ -0,0 +1,10 @@ + +Conv1 +Conv2 +Identity +Reshape +Dense +Dense +CustomOp + AvgPool3D +SoftmaxH \ No newline at end of file diff --git a/tensorflow/lite/toco/logging/testdata/toco_tf_graph.dot b/tensorflow/lite/toco/logging/testdata/toco_tf_graph.dot new file mode 100644 index 00000000000..aff70fa4a0d --- /dev/null +++ b/tensorflow/lite/toco/logging/testdata/toco_tf_graph.dot @@ -0,0 +1 @@ +digraph {a -> b} diff --git a/tensorflow/lite/toco/logging/testdata/toco_tflite_graph.dot b/tensorflow/lite/toco/logging/testdata/toco_tflite_graph.dot new file mode 100644 index 00000000000..aff70fa4a0d --- /dev/null +++ b/tensorflow/lite/toco/logging/testdata/toco_tflite_graph.dot @@ -0,0 +1 @@ +digraph {a -> b} diff --git a/tensorflow/lite/toco/logging/toco_conversion_log.proto b/tensorflow/lite/toco/logging/toco_conversion_log.proto new file mode 100644 index 00000000000..3da9affb35e --- /dev/null +++ b/tensorflow/lite/toco/logging/toco_conversion_log.proto @@ -0,0 +1,50 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto2"; + +package toco; + +// TocoConversionLog contains the analytics to be gathered when user converts +// a model to TF Lite using TOCO. +// Next ID to USE: 14. +message TocoConversionLog { + // Total ops listed by name. + repeated string op_list = 1; + // Counts of built-in ops. + // Key is op name and value is the count. + map built_in_ops = 2; + // Counts of custom ops. + map custom_ops = 3; + // Counts of select ops. + map select_ops = 4; + // The signature of operators. Including ops input/output types and shapes, + // op name and version. + repeated string op_signatures = 5; + // Input tensor types. + repeated string input_tensor_types = 6; + // Output tensor types. + repeated string output_tensor_types = 7; + // Log generation time in micro-seconds. + optional int64 log_generation_ts = 8; + // Total number of ops in the model. + optional int32 model_size = 9; + // Tensorflow Lite runtime version. + optional string tf_lite_version = 10; + // Operating System info. + optional string os_version = 11; + // Model hash string. + optional string model_hash = 12; + // Error messages emitted by TOCO during conversion. + optional string toco_err_logs = 13; +} diff --git a/tensorflow/lite/toco/model_cmdline_flags.cc b/tensorflow/lite/toco/model_cmdline_flags.cc index 7e48bd9542b..2434481272f 100644 --- a/tensorflow/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/lite/toco/model_cmdline_flags.cc @@ -132,6 +132,9 @@ bool ParseModelFlagsFromCommandLineFlags( parsed_flags.dump_graphviz_video.default_value(), "If true, will dump graphviz at each " "graph transformation, which may be used to generate a video."), + Flag("conversion_summary_dir", parsed_flags.conversion_summary_dir.bind(), + parsed_flags.conversion_summary_dir.default_value(), + "Local file directory to store the conversion logs."), Flag("allow_nonexistent_arrays", parsed_flags.allow_nonexistent_arrays.bind(), parsed_flags.allow_nonexistent_arrays.default_value(), diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index ef3533778b9..eee54e33398 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -34,15 +34,19 @@ cc_library( "//tensorflow/python:__subpackages__", ], deps = [ + "@com_google_protobuf//:protobuf_headers", "//third_party/python_runtime:headers", "//tensorflow/core:lib", "//tensorflow/lite/python/interpreter_wrapper:python_utils", + "//tensorflow/lite/toco/logging:conversion_log_util", "//tensorflow/lite/toco:model_flags_proto", "//tensorflow/lite/toco:toco_convert", + "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc", "//tensorflow/lite/toco:toco_flags_proto", "//tensorflow/lite/toco:toco_graphviz_dump_options", "//tensorflow/lite/toco:toco_port", "//tensorflow/lite/toco:toco_tooling", + "//tensorflow/lite/toco:tooling_util", "//tensorflow/core:protos_all", "//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer", ] + select({ diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index ee85546ccae..01000f590c1 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -14,14 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/toco/python/toco_python_api.h" +#include #include #include #include +#include "google/protobuf/text_format.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/lite/toco/import_tensorflow.h" +#include "tensorflow/lite/toco/logging/conversion_log_util.h" +#include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_convert.h" #include "tensorflow/lite/toco/toco_flags.pb.h" @@ -29,9 +33,48 @@ limitations under the License. #include "tensorflow/lite/toco/toco_port.h" #include "tensorflow/lite/toco/toco_tooling.h" #include "tensorflow/lite/toco/toco_types.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { +void PopulateConversionLogHelper(const toco::ModelFlags& model_flags, + toco::TocoFlags* toco_flags, + const string& input_contents_txt, + const string& output_file_contents_txt, + const string& error_message, + GraphVizDumpOptions* dump_options) { + // Make sure the graphviz file will be dumped under the same folder. + dump_options->dump_graphviz = toco_flags->conversion_summary_dir(); + // Here we construct the `toco::Model` class based on the input graph def, + // it will then be used to populate the conversion log. + // TODO(haoliang): Don't depend on `toco::Model`. + std::unique_ptr imported_model = + toco::Import(*toco_flags, model_flags, input_contents_txt); + // Dump pre-conversion toco logs. + TocoConversionLog toco_log_before; + PopulateConversionLog(*imported_model, &toco_log_before); + std::ofstream osstream_before(toco_flags->conversion_summary_dir() + + "/toco_log_before.pb"); + toco_log_before.SerializeToOstream(&osstream_before); + osstream_before.close(); + toco::LogDump(toco::kLogLevelModelChanged, "tf_graph", *imported_model); + + // Populate the post-conversion log, for convenient initiate the + // `toco::Model` class from the generated flatbuffer. + toco_flags->set_input_format(toco::FileFormat::TFLITE); + std::unique_ptr flatbuffer_model = + toco::Import(*toco_flags, model_flags, output_file_contents_txt); + // Dump post-conversion toco logs. + TocoConversionLog toco_log_after; + PopulateConversionLog(*flatbuffer_model, &toco_log_after); + toco_log_after.set_toco_err_logs(error_message); + std::ofstream ostream_after(toco_flags->conversion_summary_dir() + + "/toco_log_after.pb"); + toco_log_after.SerializeToOstream(&ostream_after); + ostream_after.close(); + toco::LogDump(toco::kLogLevelModelChanged, "tflite_graph", *flatbuffer_model); +} + // NOTE(aselle): We are using raw PyObject's here because we want to make // sure we input and output bytes rather than unicode strings for Python3. PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, @@ -124,6 +167,11 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, status = tensorflow::ConvertGraphDefToTFLiteFlatBuffer( model_flags, toco_flags, debug_info, graph_def, &output_file_contents_txt); + if (!toco_flags.conversion_summary_dir().empty()) { + PopulateConversionLogHelper(model_flags, &toco_flags, input_contents_txt, + output_file_contents_txt, + status.error_message(), &dump_options); + } } else { status = Convert(input_contents_txt, toco_flags, model_flags, &output_file_contents_txt, &arithmetic_ops_count); diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index d17c5f72caa..422f5129412 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -38,7 +38,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 31. +// Next ID to use: 32. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -218,4 +218,8 @@ message TocoFlags { // runtime memory offsets for activation Tensors (with 128 bits alignment) // and error out on models with undetermined Tensor shape. (Default: True) optional bool allow_dynamic_tensors = 30 [default = true]; + + // Full filepath of the folder to dump conversion logs. This includes a global + // view of the conversion process, and user can choose to submit those logs. + optional string conversion_summary_dir = 31; }