Multiple enhancements to graphviz output:
- nodes organized by subgraphs inferred from array names
- subgraphs and ops include math op count estimates
- better organized labels on all nodes
- operators have numbered inputs and outputs when two or more exist
- specific node shapes for inputs/outputs
- node color scheme unchanged

PiperOrigin-RevId: 233446311
This commit is contained in:
parent
734e10d80b
commit
08703e1aad
@@ -378,6 +378,7 @@ cc_library(
        ":types_proto_cc",
        "//tensorflow/core:lib",
        "//tensorflow/lite/kernels/internal:types",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_googlesource_code_re2//:re2",
        "@protobuf_archive//:protobuf_headers",
@@ -15,17 +15,21 @@ limitations under the License.
#include "tensorflow/lite/toco/dump_graphviz.h"

#include <cmath>
#include <functional>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "re2/re2.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

using toco::port::AppendF;
using toco::port::StringF;
@@ -33,72 +37,158 @@ using toco::port::StringF;
namespace toco {
namespace {

// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Note: tooltips are only supported on SVGs in Chrome.
constexpr char kSubgraphFmt[] =
    R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
constexpr char kArrayNodeFmt[] =
    R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
constexpr char kOpNodeFmt[] =
    R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
constexpr char kInputEdgeFmt[] =
    R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
constexpr char kOutputEdgeFmt[] =
    R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
constexpr char kRNNBackEdgeFmt[] =
    R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
constexpr char kUnicodeMult[] = "\u00D7";
constexpr char kUnicodeEllipsis[] = " \u2026 ";
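For orientation, here is a hypothetical expansion of kArrayNodeFmt (the array name and values are invented for illustration, not taken from the diff):

  string dot;
  AppendF(&dot, kArrayNodeFmt, "conv1/weights", "<...>", "conv1/weights",
          "cylinder", "#4285F4", "#FFFFFF");
  // dot now holds, roughly:
  //   "conv1/weights" [label=<...> tooltip="conv1/weights" shape=cylinder
  //                    style=filled fillcolor="#4285F4" fontcolor="#FFFFFFDD"];

The cylinder shape and bold blue fill are what GetArrayColorAndShape below assigns to constant (buffer-backed) arrays, and "<...>" stands in for the HTML-like label that GetArrayLabel builds.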

class Color {
 public:
  Color() {}
  Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
  explicit Color(uint32 word)
      : r_((word & 0x00FF0000) >> 16),
        g_((word & 0x0000FF00) >> 8),
        b_((word & 0x000000FF) >> 0) {}

  // Returns the string serialization of this color in graphviz format,
  // for use as 'fillcolor' in boxes.
  string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); }
  string AsHexString() const { return StringF("#%.2X%.2X%.2X", r_, g_, b_); }
  // The color to use for this node; will be used as 'fillcolor'
  // for its box. See Color::AsHexString. A suitable, different
  // color will be chosen for the 'fontcolor' for the inside text
  // label, see Color::TextColorString.
  // Returns the serialization in graphviz format of a suitable color to use
  // 'fontcolor' in the same boxes. It should be black or white, whichever
  // offers the better contrast from FillColorString().
  // offers the better contrast from AsHexString().
  string TextColorString() const {
    // https://en.wikipedia.org/wiki/Relative_luminance
    const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
    const uint8 l = luminance > 128.f ? 0 : 255;
    return StringF("%.2X%.2X%.2X", l, l, l);
    return StringF("#%.2X%.2X%.2X", l, l, l);
  }

 private:
  uint8 r_ = 0, g_ = 0, b_ = 0;
};
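To make the TextColorString heuristic concrete, here is the relative-luminance arithmetic worked out for the bold red used for convolution-style ops later in this file, (r, g, b) = (0xC5, 0x39, 0x29) = (197, 57, 41):

  L = 0.2126*197 + 0.7152*57 + 0.0722*41 ≈ 41.9 + 40.8 + 3.0 ≈ 85.7

Since 85.7 < 128, the label text comes out white ("#FFFFFF"); a pale fill such as 0xF5F5F5 gives L ≈ 245 and gets black text instead.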

struct NodeProperties {
  // The text to display inside the box for this node.
  string label;
  // The color to use for this node; will be used as 'fillcolor'
  // for its box. See Color::FillColorString. A suitable, different
  // color will be chosen for the 'fontcolor' for the inside text
  // label, see Color::TextColorString.
  Color color;
  float log2_buffer_size;
};

Color HashStringToColor(string s) {
  // Return a unique color for a name.
  //
  // This function removes TensorFlow anti-collision suffixes (e.g. "_2"),
  // hashes the string to a uint32, then twiddles some bits to get a light and
  // subtle color. This seems to be a good heuristic for keeping enough of the
  // name to hash to a unique color while still revealing structure through
  // naming similarities.
  //
  // The regular expression "_\d+" matches any underscore followed by numbers,
  // which we strip out. Examples:
  //
  //   "Conv"       -> "Conv"
  //   "Conv_2"     -> "Conv"
  //   "Conv_72"    -> "Conv"
  //   "Pad_1_bias" -> "Pad_bias"
  //   "Conv_abc"   -> "Conv_abc"

  // All colors in this file are from:
  // https://material.io/guidelines/style/color.html
  RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
  uint32 color_word = std::hash<std::string>{}(s);
  color_word |= 0x00E0E0E0;
  return Color(color_word);
}

Color GetColorForArray(const Model& model, const string& array_name) {
void GetArrayColorAndShape(const Model& model, const string& array_name,
                           Color* color, string* shape) {
  // All colors in this file are from:
  // https://material.io/guidelines/style/color.html
  // Arrays involved in RNN back-edges have a different color
  for (const auto& rnn_state : model.flags.rnn_states()) {
    // RNN state, fed by a back-edge. Bold color.
    if (array_name == rnn_state.state_array()) {
      return Color(0x0F, 0x9D, 0x58);
      *color = Color(0x0F, 0x9D, 0x58);
      *shape = "invhouse";
      return;
    }
    // RNN back-edge source, feeding a RNN state.
    // Light tone of the same color as RNN states.
    if (array_name == rnn_state.back_edge_source_array()) {
      return Color(0xB7, 0xE1, 0xCD);
      *color = Color(0xB7, 0xE1, 0xCD);
      *shape = "house";
      return;
    }
  }
  // Constant parameter arrays have their own bold color
  if (model.GetArray(array_name).buffer) {
    return Color(0x42, 0x85, 0xF4);
    *color = Color(0x42, 0x85, 0xF4);
    *shape = "cylinder";
    return;
  }
  // Remaining arrays are activations.
  // We use gray colors for them because they are the majority
  // of arrays so we want to highlight other arrays instead of them.
  // First, we use a bolder gray for input/output arrays:
  if (IsInputArray(model, array_name)) {
    return Color(0x9E, 0x9E, 0x9E);
    *color = Color(0x9E, 0x9E, 0x9E);
    *shape = "invhouse";
    return;
  }
  if (IsOutputArray(model, array_name)) {
    return Color(0x9E, 0x9E, 0x9E);
    *color = Color(0x9E, 0x9E, 0x9E);
    *shape = "house";
    return;
  }
  // Remaining arrays are intermediate activation arrays.
  // Lighter tone of the same grey as for input/output arrays:
  // We want these to be very discreet.
  return Color(0xF5, 0xF5, 0xF5);
  *color = Color(0xF5, 0xF5, 0xF5);
  *shape = "box";
}

string GetArrayCompassPt(const Model& model, const string& array_name) {
  // The "compass point" is the point on the node where edge connections are
  // made. For most arrays we don't care, but inputs and outputs look better
  // connected at the tip of the "house" and "invhouse" shapes used. So we
  // append ":n" and ":s" respectively for those.
  for (const auto& rnn_state : model.flags.rnn_states()) {
    // RNN state is essentially an input
    if (array_name == rnn_state.state_array()) {
      return ":s";
    }
    // RNN back-edge source is essentially an output
    if (array_name == rnn_state.back_edge_source_array()) {
      return ":n";
    }
  }
  if (IsInputArray(model, array_name)) {
    return ":s";
  }
  if (IsOutputArray(model, array_name)) {
    return ":n";
  }
  return "";
}

void AppendArrayVal(string* string, Array const& array, int index) {
@@ -141,239 +231,550 @@ void AppendArrayVal(string* string, Array const& array, int index) {
  }
}

NodeProperties GetPropertiesForArray(const Model& model,
                                     const string& array_name) {
  NodeProperties node_properties;
  node_properties.color = GetColorForArray(model, array_name);
  node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}});
  node_properties.log2_buffer_size = 0.0f;
typedef std::map<string, string> Attributes;

  // Append array shape to the label.
  auto& array = model.GetArray(array_name);
  AppendF(&node_properties.label, "\\nType: %s",
          ArrayDataTypeName(array.data_type));

  if (array.has_shape()) {
    auto& array_shape = array.shape();
    node_properties.label += "\\n[";
    for (int id = 0; id < array_shape.dimensions_count(); id++) {
      if (id == 0) {
        AppendF(&node_properties.label, "%d", array_shape.dims(id));
      } else {
        // 0x00D7 is the unicode multiplication symbol
        AppendF(&node_properties.label, "\u00D7%d", array_shape.dims(id));
      }
    }
    node_properties.label += "]";

    int buffer_size = 0;
    if (IsNonEmpty(array.shape())) {
      buffer_size = RequiredBufferSizeForShape(array.shape());
      node_properties.log2_buffer_size =
          std::log2(static_cast<float>(buffer_size));
    }

    if (array.buffer) {
      const auto& array = model.GetArray(array_name);
      if (buffer_size <= 4) {
        AppendF(&node_properties.label, " = ");
        if (array.shape().dimensions_count() > 0) {
          AppendF(&node_properties.label, "{");
        }
        for (int i = 0; i < buffer_size; i++) {
          AppendArrayVal(&node_properties.label, array, i);
          if (i + 1 < buffer_size) {
            AppendF(&node_properties.label, ", ");
          }
        }
      } else {
        AppendF(&node_properties.label, "\\n = ");
        if (array.shape().dimensions_count() > 0) {
          AppendF(&node_properties.label, "{");
        }
        AppendArrayVal(&node_properties.label, array, 0);
        AppendF(&node_properties.label, ", ");
        AppendArrayVal(&node_properties.label, array, 1);
        // 0x2026 is the unicode ellipsis symbol
        AppendF(&node_properties.label, " \u2026 ");
        AppendArrayVal(&node_properties.label, array, buffer_size - 2);
        AppendF(&node_properties.label, ", ");
        AppendArrayVal(&node_properties.label, array, buffer_size - 1);
      }
      if (array.shape().dimensions_count() > 0) {
        AppendF(&node_properties.label, "}");
      }
    }
string AttributesToHtml(Attributes attributes) {
  string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }

  if (array.minmax) {
    AppendF(&node_properties.label, "\\nMinMax: [%.7g, %.7g]",
            array.minmax->min, array.minmax->max);
  }

  if (array.quantization_params) {
    AppendF(&node_properties.label, "\\nQuantization: %7g * (x - %d)",
            array.quantization_params->scale,
            array.quantization_params->zero_point);
  }

  if (array.alloc) {
    AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)",
            array.alloc->start, array.alloc->end);
  }

  return node_properties;
  return html;
}

NodeProperties GetPropertiesForOperator(const Operator& op) {
  NodeProperties node_properties;
  if (op.type == OperatorType::kUnsupported) {
    node_properties.label =
        static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
  } else {
    node_properties.label =
        string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
string GetArrayLabel(const Model& model, const string& array_id) {
  string html;

  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
  html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";

  auto& array = model.GetArray(array_id);
  if (array.buffer) {
    // "cylinder" shapes require some extra head room.
    html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
  }

  // "Primary" name of array (last non-slash delimited group of characters).
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
  AppendF(&html, R"CODE(%s)CODE",
          std::vector<string>(absl::StrSplit(array_id, '/')).back());
  html += R"CODE(</I></FONT>)CODE";
  html += "</TD></TR>";

  // Array data type and dimensions
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
  // Type
  html += ArrayDataTypeName(array.data_type);
  // Shape
  if (array.has_shape()) {
    auto& array_shape = array.shape();
    html += "[";
    for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
      AppendF(&html, "%d", array_shape.dims(dim));
      if (dim + 1 < array_shape.dimensions_count()) {
        html += kUnicodeMult;
      }
    }
    html += "]";
  }

  // Small buffer sample
  int buffer_size = 0;
  if (array.buffer) {
    buffer_size = RequiredBufferSizeForShape(array.shape());
  }
  if ((buffer_size > 0) && (buffer_size <= 4)) {
    html += " = ";
    if (array.shape().dimensions_count() > 0) {
      html += "{";
    }
    for (int i = 0; i < buffer_size; i++) {
      AppendArrayVal(&html, array, i);
      if (i + 1 < buffer_size) {
        html += ", ";
      }
    }
    if (array.shape().dimensions_count() > 0) {
      html += "}";
    }
  }
  html += R"CODE(</B></FONT>)CODE";
  html += "</TD></TR>";

  // Large buffer samples get their own line
  if (buffer_size > 4) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
    AppendArrayVal(&html, array, 0);
    html += ", ";
    AppendArrayVal(&html, array, 1);
    html += kUnicodeEllipsis;
    AppendArrayVal(&html, array, buffer_size - 2);
    html += ", ";
    AppendArrayVal(&html, array, buffer_size - 1);
    html += "}</TD></TR>";
  }

  // Other array properties
  Attributes attrs;
  if (array.minmax) {
    attrs["minmax"] =
        StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
  }
  if (array.quantization_params) {
    attrs["quant"] = StringF("%7g\u00B7(x-%d)",  // Unicode "cdot"
                             array.quantization_params->scale,
                             array.quantization_params->zero_point);
  }
  if (array.alloc) {
    attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
  }
  html += AttributesToHtml(attrs);

  // output array_id in ultra-small font so it can be searched and copied.
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
  AppendF(&html, R"CODE("%s")CODE", array_id);
  html += R"CODE(</FONT>)CODE";
  html += "</TD></TR>";

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";
  return html;
}

Attributes GetOpAttributes(const Model& model, const Operator& op) {
  Attributes attrs;
  switch (op.fused_activation_function) {
    case FusedActivationFunctionType::kRelu:
      AppendF(&node_properties.label, "\\nReLU");
      attrs["func"] = "ReLU";
      break;
    case FusedActivationFunctionType::kRelu6:
      AppendF(&node_properties.label, "\\nReLU6");
      attrs["func"] = "ReLU6";
      break;
    case FusedActivationFunctionType::kRelu1:
      AppendF(&node_properties.label, "\\nReLU1");
      attrs["func"] = "ReLU1";
      break;
    default:
      break;
  }
  // Additional information for some of the operators.
  // Output state of member vars on derived operators.
  switch (op.type) {
    case OperatorType::kConv: {
      const auto& conv_op = static_cast<const ConvOperator&>(op);
      node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
      AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
              conv_op.stride_height,
              conv_op.padding.type == PaddingType::kSame ? "S" : "V");
      string stride;
      AppendF(&stride, "%d", conv_op.stride_width);
      stride += kUnicodeMult;
      AppendF(&stride, "%d", conv_op.stride_height);
      attrs["stride"] = stride;
      attrs["padding"] =
          (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
      break;
    }
    case OperatorType::kDepthwiseConv: {
      const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
      node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
      AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width,
              conv_op.stride_height,
              conv_op.padding.type == PaddingType::kSame ? "S" : "V");
      break;
    }
    case OperatorType::kFullyConnected: {
      node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
      const auto& depthconv_op = static_cast<const ConvOperator&>(op);
      string stride;
      AppendF(&stride, "%d", depthconv_op.stride_width);
      stride += kUnicodeMult;
      AppendF(&stride, "%d", depthconv_op.stride_height);
      attrs["stride"] = stride;
      attrs["padding"] =
          (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
      break;
    }
    case OperatorType::kFakeQuant: {
      const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
      node_properties.color = Color(0xC5, 0x39, 0x29);  // Bolder color
      attrs["bits"] = StringF("%d", fakequant_op.num_bits);
      if (fakequant_op.minmax) {
        AppendF(&node_properties.label, "\\n%dbit [%g,%g]",
                fakequant_op.num_bits, fakequant_op.minmax->min,
                fakequant_op.minmax->max);
        attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
                                 fakequant_op.minmax->max);
      } else {
        AppendF(&node_properties.label, "\\n%dbit [?,?]",
                fakequant_op.num_bits);
        attrs["range"] = "[?,?]";
      }
      break;
    }
    default:
      node_properties.color = Color(0xDB, 0x44, 0x37);
      break;
  }
  int64 math_ops_count;
  if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
      (math_ops_count != 0)) {
    attrs["math"] = FormattedNumber(math_ops_count) + "ops";
  }

  return node_properties;
  return attrs;
}

}  // namespace

void DumpGraphviz(const Model& model, string* output_file_contents) {
  AppendF(output_file_contents, "digraph Computegraph {\n");
  // 'nslimit' is a graphviz (dot) parameter that limits the iterations during
  // the layout phase. Omitting it allows infinite iterations, causing some
  // complex graphs to never finish. A value of 125 produces good graphs
  // while allowing complex graphs to finish.
  AppendF(output_file_contents, "\t nslimit=125;\n");

  constexpr char kNodeFormat[] =
      "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", "
      "fontcolor = \"#%sDD\"];\n";

  constexpr char kEdgeFormat[] =
      "\t \"%s\" -> \"%s\" [penwidth=%f, weight=%f];\n";

  constexpr char kRNNBackEdgeFormat[] =
      "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";

  for (const auto& array_kv : model.GetArrayMap()) {
    // Add node for array.
    const string& array_name = array_kv.first;
    const auto& array_properties = GetPropertiesForArray(model, array_name);
    AppendF(output_file_contents, kNodeFormat, array_name,
            array_properties.label, "octagon",
            array_properties.color.FillColorString().c_str(),
            array_properties.color.TextColorString().c_str());
Color GetOpColor(const Operator& op) {
  if ((op.type == OperatorType::kDepthwiseConv) ||
      (op.type == OperatorType::kConv) ||
      (op.type == OperatorType::kFullyConnected) ||
      (op.type == OperatorType::kFakeQuant)) {
    // Give some ops a bolder red
    return Color(0xC5, 0x39, 0x29);
  } else {
    return Color(0xDB, 0x44, 0x37);
  }
}

string GetOpLabel(const Model& model, const Operator& op) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";

  // Input Ports
  if (!op.inputs.empty()) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
    // Distribute evenly using a sub-table
    html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
    html += R"CODE(<TR>)CODE";
    for (int i = 0; i < op.inputs.size(); i++) {
      html += R"CODE(<TD PORT=")CODE";
      AppendF(&html, "i%d", i);
      html += R"CODE(">)CODE";
      if (op.inputs.size() > 1) {
        // Only number inputs when op has two or more inputs
        AppendF(&html, "%d", i);
      }
      html += "</TD>";
    }
    html += "</TR>";
    html += R"CODE(</TABLE></TD></TR>)CODE";
  }

  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
  if (op.type == OperatorType::kUnsupported) {
    html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
  } else {
    html += string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
  }
  html += R"CODE(</B></FONT>)CODE";
  html += "</TD></TR>";

  // Attributes
  Attributes attrs = GetOpAttributes(model, op);
  html += AttributesToHtml(attrs);

  // Output Ports
  if (!op.outputs.empty()) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
    // Distribute evenly using a sub-table
    html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
    html += R"CODE(<TR>)CODE";
    for (int i = 0; i < op.outputs.size(); i++) {
      html += R"CODE(<TD PORT=")CODE";
      AppendF(&html, "o%d", i);
      html += R"CODE(">)CODE";
      if (op.outputs.size() > 1) {
        // Only number outputs when op has two or more outputs
        AppendF(&html, "%d", i);
      }
      html += "</TD>";
    }
    html += "</TR>";
    html += R"CODE(</TABLE></TD></TR>)CODE";
  }

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";

  return html;
}

float GetLog2BufferSize(const Model& model, const string& array_id) {
  auto& array = model.GetArray(array_id);
  if (array.has_shape()) {
    int buffer_size = 0;
    if (IsNonEmpty(array.shape())) {
      buffer_size = RequiredBufferSizeForShape(array.shape());
      return std::log2(static_cast<float>(buffer_size));
    }
  }
  return 0.0f;
}

string GetOpId(int op_index) { return StringF("op%05d", op_index); }

void DumpOperator(const Model& model, string* output_file, int op_index) {
  // Dump node for operator.
  const Operator& op = *model.operators[op_index];
  Color color = GetOpColor(op);
  string label = GetOpLabel(model, op);
  string op_id = GetOpId(op_index);
  AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
          color.TextColorString());
}

void DumpOperatorEdges(const Model& model, string* output_file, int op_index) {
  // Inputs
  const Operator& op = *model.operators[op_index];
  string op_id = GetOpId(op_index);
  for (int i = 0; i < op.inputs.size(); i++) {
    const auto& input = op.inputs[i];
    if (!model.HasArray(input)) {
      // Connected arrays should _always_ exist. Except, perhaps, during
      // development.
      continue;
    }
    float log2_buffer_size = GetLog2BufferSize(model, input);
    // Draw lines that transport more data thicker (Otherwise, where would the
    // data fit? right?).
    float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
    // Keep edges that transport more data shorter than those with less.
    float weight = std::max(1.0f, log2_buffer_size);
    if (!IsInputArray(model, input) &&
        GetOpWithOutput(model, input) == nullptr) {
      // Give the main line of data flow a straighter path by penalizing edges
      // to standalone buffers. Weights are generally very large buffers that
      // would otherwise skew the layout.
      weight = 1.0f;
    }
    string compass_pt = GetArrayCompassPt(model, input);
    AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
            weight);
  }
  // Outputs
  for (int i = 0; i < op.outputs.size(); i++) {
    const auto& output = op.outputs[i];
    if (!model.HasArray(output)) {
      continue;
    }
    float log2_buffer_size = GetLog2BufferSize(model, output);
    // See comments above regarding weight and line_width calculations.
    float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
    float weight = std::max(1.0f, log2_buffer_size);
    if (!IsArrayConsumed(model, output)) {
      weight = 1.0f;
    }
    string compass_pt = GetArrayCompassPt(model, output);
    AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
            line_width, weight);
  }
}
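For a sense of scale of the penwidth/weight mapping above (hypothetical shapes, not taken from the diff): a float activation of shape [1, 224, 224, 3] holds 1*224*224*3 = 150528 values, so

  log2(150528) ≈ 17.2  ->  penwidth ≈ 17.2 / 3 ≈ 5.7,  weight ≈ 17.2

while a scalar array bottoms out at the minimums, penwidth 0.5 and weight 1.0.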

struct Node {
  Node() : math_ops(0) {}
  // Name used as a key in the model's array map
  string array_id;

  // Estimated number of math ops incurred by this node (the sum of the op
  // with this array as 1st output, plus all children nodes).
  int64 math_ops;

  // A map of child nodes keyed by name.
  std::map<const string, std::unique_ptr<Node>> children;
};

string GetSubgraphLabel(Node const& node, const string& subgraph) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";

  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
  html += subgraph;
  html += R"CODE(</I></FONT>)CODE";
  html += "</TD></TR>";

  // Attributes
  Attributes attrs;
  if (node.math_ops > 0) {
    attrs["math"] = FormattedNumber(node.math_ops) + "ops";
  }
  html += AttributesToHtml(attrs);

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";

  return html;
}

void DumpSubgraphHeader(string* output_file, Node const& node,
                        const string& node_name) {
  Color color = HashStringToColor(node_name);
  string label = GetSubgraphLabel(node, node_name);
  AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
}

void DumpArray(const Model& model, string* output_file,
               const string& array_id) {
  Color color;
  string shape;
  GetArrayColorAndShape(model, array_id, &color, &shape);
  string label = GetArrayLabel(model, array_id);
  AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
          color.AsHexString(), color.TextColorString());

  // Ops are placed in the same subgraph as their first output.
  for (int op_index = 0; op_index < model.operators.size(); op_index++) {
    const Operator& op = *model.operators[op_index];
    // Add node for operator.
    auto op_properties = GetPropertiesForOperator(op);
    string operator_id = StringF("op%05d", op_index);
    AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label,
            "box", op_properties.color.FillColorString().c_str(),
            op_properties.color.TextColorString().c_str());
    // Add edges for all inputs of the operator.
    for (const auto& input : op.inputs) {
      if (!model.HasArray(input)) {
        // Arrays should _always_ exist. Except, perhaps, during development.
        continue;
      }
      auto array_properties = GetPropertiesForArray(model, input);
      // Draw lines that transport more data thicker (Otherwise, where would the
      // data fit? right?).
      float line_width =
          std::max(0.5f, array_properties.log2_buffer_size / 3.0f);
      // Keep edges that transport more data shorter than those with less.
      float weight = std::max(1.0f, array_properties.log2_buffer_size);
      if (!IsInputArray(model, input) &&
          GetOpWithOutput(model, input) == nullptr) {
        // Give the main line of data flow a straighter path by penalizing edges
        // to standalone buffers. Weights are generally very large buffers that
        // otherwise skew the layout without this.
        weight = 1.0f;
      }
      AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width,
              weight);
    }
    // Add edges for all outputs of the operator.
    for (const auto& output : op.outputs) {
      if (!model.HasArray(output)) {
        // Arrays should _always_ exist. Except, perhaps, during development.
        continue;
      }
      auto array_properties = GetPropertiesForArray(model, output);
      // See comments above regarding weight and line_width calculations.
      float line_width =
          std::max(0.5f, array_properties.log2_buffer_size / 3.0f);
      float weight = std::max(1.0f, array_properties.log2_buffer_size);
      if (!IsArrayConsumed(model, output)) {
        weight = 1.0f;
      }
      AppendF(output_file_contents, kEdgeFormat, operator_id, output,
              line_width, weight);
    if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
      DumpOperator(model, output_file, op_index);
    }
  }
}

void DumpNode(const Model& model, string* output_file, const string& node_name,
              Node const& node) {
  bool not_root = !node_name.empty();
  if (not_root) {
    DumpSubgraphHeader(output_file, node, node_name);
  }

  for (const auto& child : node.children) {
    if (!child.second->array_id.empty()) {
      // Dump array if this node possesses one.
      DumpArray(model, output_file, child.second->array_id);
    }
    // Note that it is always possible to have children. Unlike a filesystem,
    // the existence of array "foo/bar" does _not_ prevent other arrays, such
    // as "foo/bar/baz", from being nested beneath it.
    DumpNode(model, output_file, child.first, *child.second);
  }

  if (not_root) {
    // End subgraph
    AppendF(output_file, " }\n");
  }
}

int64 GetArithmeticOpsCount(const Model& model, const string& array_id) {
  for (const auto& op : model.operators) {
    if (!op->outputs.empty() && op->outputs[0] == array_id) {
      int64 count;
      if (EstimateArithmeticOpsCount(model, *op, &count)) {
        return count;
      } else {
        return 0;
      }
    }
  }
  return 0;
}

void InsertNode(const Model& model, const string& array_id, Node* node,
                std::vector<string> prefixes, int64* math_ops) {
  if (prefixes.empty()) {
    // Base case: store array in this node.
    node->array_id = array_id;
    *math_ops = GetArithmeticOpsCount(model, array_id);
  } else {
    // Insert into the sub-tree for that prefix.
    string prefix = prefixes.back();
    prefixes.pop_back();
    if (node->children.count(prefix) == 0) {
      // Create a new node if this prefix is unseen.
      node->children[prefix] = absl::make_unique<Node>();
    }
    InsertNode(model, array_id, node->children[prefix].get(), prefixes,
               math_ops);
  }
  // Sum estimated math ops into all nodes.
  node->math_ops += *math_ops;
}

void BuildArrayTree(const Model& model, Node* tree) {
  // Delimit array names by path "/", then place into a tree based on this path.
  for (const auto& array_id : model.GetArrayMap()) {
    std::vector<string> prefixes = absl::StrSplit(array_id.first, '/');
    std::reverse(prefixes.begin(), prefixes.end());
    int64 math_ops;  // Temporary storage for math ops used during recursion.
    InsertNode(model, array_id.first, tree, prefixes, &math_ops);
  }
}
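A minimal sketch of how this keys the subgraphs, using an invented array id (not one from the diff) and reusing the includes at the top of this file:

  // "vgg16/conv1/weights" is split on '/', reversed, and consumed back to
  // front by InsertNode, producing nested Nodes vgg16 -> conv1 -> weights;
  // the array id is stored on the innermost node and the op-count estimate
  // is summed into every ancestor on the way back up.
  std::vector<string> prefixes = absl::StrSplit("vgg16/conv1/weights", '/');
  // prefixes == {"vgg16", "conv1", "weights"}
  std::reverse(prefixes.begin(), prefixes.end());
  // prefixes == {"weights", "conv1", "vgg16"}; pop_back() yields "vgg16" first.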

string GetGraphLabel(const Model& model, const string& graph_name) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";

  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";

  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
  html += graph_name;
  html += R"CODE(</I></B></FONT>)CODE";
  html += "</TD></TR>";

  // Attributes
  Attributes attrs;
  attrs["arrays"] = StringF("%d", model.GetArrayMap().size());
  if (!model.optional_arrays.empty()) {
    attrs["optional arrays"] = StringF("%d", model.optional_arrays.size());
  }
  attrs["operators"] = StringF("%d", model.operators.size());
  int64 ops_count;
  if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
    attrs["math"] = FormattedNumber(ops_count) + "ops";
  }
  if (model.transient_data_size > 0) {
    attrs["transient data size"] =
        StringF("%d KiB", model.transient_data_size / 1024);
  }
  if (model.transient_data_alignment > 0) {
    attrs["transient data alignment"] =
        StringF("%d bytes", model.transient_data_alignment);
  }
  html += AttributesToHtml(attrs);

  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";

  return html;
}
}  // namespace

void DumpGraphviz(const Model& model, string* output_file,
                  const string& graph_name) {
  // Start graphviz format
  AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));

  // Organize arrays into a tree for subgraphing
  Node tree;
  BuildArrayTree(model, &tree);
  DumpNode(model, output_file, "", tree);

  // Dump edges outside all subgraphs (otherwise the referred-to nodes are
  // implicitly included in that subgraph).
  for (int op_index = 0; op_index < model.operators.size(); op_index++) {
    DumpOperatorEdges(model, output_file, op_index);
  }

  // Dump RNN Backedges
  for (const auto& rnn_state : model.flags.rnn_states()) {
    AppendF(output_file_contents, kRNNBackEdgeFormat,
            rnn_state.back_edge_source_array(), rnn_state.state_array());
    AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
            rnn_state.state_array());
  }

  AppendF(output_file_contents, "}\n");
  // End graphviz format
  AppendF(output_file, "}\n");
}
}  // namespace toco
@@ -21,7 +21,8 @@ limitations under the License.

namespace toco {

void DumpGraphviz(const Model& model, string* output_file_contents);
void DumpGraphviz(const Model& model, string* output_file_contents,
                  const string& graph_name);

}  // namespace toco
@@ -454,7 +454,7 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
      return status;
    } break;
    case GRAPHVIZ_DOT:
      DumpGraphviz(model, output_file_contents);
      DumpGraphviz(model, output_file_contents, "Computation Graph");
      break;
    default:
      LOG(FATAL) << "Unhandled output_format='"
@@ -66,29 +66,29 @@ string LogName(const Operator& op) {
string ArrayDataTypeName(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return "Float";
      return "float";
    case ArrayDataType::kInt8:
      return "Int8";
      return "int8";
    case ArrayDataType::kUint8:
      return "Uint8";
      return "uint8";
    case ArrayDataType::kInt16:
      return "Int16";
      return "int16";
    case ArrayDataType::kUint16:
      return "Uint16";
      return "uint16";
    case ArrayDataType::kInt32:
      return "Int32";
      return "int32";
    case ArrayDataType::kUint32:
      return "Uint32";
      return "uint32";
    case ArrayDataType::kInt64:
      return "Int64";
      return "int64";
    case ArrayDataType::kUint64:
      return "Uint64";
      return "uint64";
    case ArrayDataType::kString:
      return "String";
      return "string";
    case ArrayDataType::kBool:
      return "Bool";
      return "bool";
    case ArrayDataType::kComplex64:
      return "Complex64";
      return "complex64";
    case ArrayDataType::kNone:
      return "None";
    default:
@@ -538,7 +538,8 @@ void DumpGraphvizVideoFrame(const Model& model) {
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump);
  DumpGraphviz(model, &graphviz_dump,
               toco::port::StringF("VIDEO frame:%05d", dump_id));
  std::size_t hash = std::hash<string>{}(graphviz_dump);
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
@@ -561,7 +562,7 @@ void LogDump(int log_level, const string& message, const Model& model) {
  if (!dump_options.dump_graphviz.empty()) {
    string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump);
    DumpGraphviz(model, &graphviz_dump, message);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
@@ -1863,119 +1864,140 @@ string CreateInt32Array(Model* model, const string& param_name,
  return param_array_name;
}

bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64* result) {
  switch (op.type) {
    case OperatorType::kFullyConnected:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!output_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      int64 cols = 1;
      for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
        cols *= output_array.shape().dims(i);
      }
      const int64 cost_per_col =
          2 * RequiredBufferSizeForShape(weights_array.shape());
      *result = cost_per_col * cols;
      if (op.inputs.size() > 2) {
        // There is a bias vector. One more op per output value.
        *result += RequiredBufferSizeForShape(output_array.shape());
      }
      break;
    }
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kAddN: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // AddN cost is roughly the same cost as N-1 Adds.
      const int64 num_adds = op.inputs.size() - 1;
      *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kLogistic:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kTanh: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // As a very rough ballpark, the cost of evaluating a math function
      // such as tanh or logistic is about 32 multiplications, and about as
      // many additions/subtractions. (Just a power-of-two order-of-magnitude
      // from looking at actual implementations that we use in runtime/ code).
      *result = 64 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kMaxPool: {
      const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                maxpool.kheight * maxpool.kwidth;
      break;
    }
    case OperatorType::kAveragePool: {
      const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                avgpool.kheight * avgpool.kwidth;
      break;
    }
    case OperatorType::kL2Pool: {
      const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // The sum of squares requires (kheight*kwidth) multiply-adds,
      // and then there is the sqrt which we ballpark at 32 ops.
      const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
      *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
      break;
    }
    case OperatorType::kL2Normalization: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // Computing the squared L2 norm is N multiply-adds so 2N ops,
      // then the single inverse-sqrt is negligible, then we multiply each
      // value by the resulting multiplier, so an extra N ops. Total 3N ops.
      *result = 3 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    default:
      *result = 0;
      break;
  }
  return true;
}
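A worked instance of the Conv branch above, with hypothetical shapes (not taken from the diff): an output of shape [1, 112, 112, 64] and weights of shape [64, 3, 3, 3] give

  cols = 1*112*112 = 12544,  cost_per_col = 2 * (64*3*3*3) = 3456
  *result = 12544 * 3456 = 43352064, plus 1*112*112*64 = 802816 if a bias input is present

for roughly 44.2 M ops, which is the figure the "math" attribute on the op node would then report.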

bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
  int64 total = 0;
  for (const auto& op : model.operators) {
    switch (op->type) {
      case OperatorType::kFullyConnected:
      case OperatorType::kConv:
      case OperatorType::kDepthwiseConv: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        const auto& weights_array = model.GetArray(op->inputs[1]);
        if (!output_array.has_shape() || !weights_array.has_shape()) {
          return false;
        }
        int cols = 1;
        for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
          cols *= output_array.shape().dims(i);
        }
        const int64 cost_per_col =
            2 * RequiredBufferSizeForShape(weights_array.shape());
        total += cost_per_col * cols;
        if (op->inputs.size() > 2) {
          // There is a bias vector. One more op per output value.
          total += RequiredBufferSizeForShape(output_array.shape());
        }
        break;
      }
      case OperatorType::kAdd:
      case OperatorType::kSub:
      case OperatorType::kMul: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        total += RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      case OperatorType::kAddN: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // AddN cost is roughly the same cost as N-1 Adds.
        const int num_adds = op->inputs.size() - 1;
        total += num_adds * RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      case OperatorType::kLogistic:
      case OperatorType::kSoftmax:
      case OperatorType::kLogSoftmax:
      case OperatorType::kTanh: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // As a very rough ballpark, the cost of evaluating a math function
        // such as tanh or logistic is about 32 multiplications, and about as
        // many additions/subtractions. (Just a power-of-two order-of-magnitude
        // from looking at actual implementations that we use in runtime/ code).
        total += 64 * RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      case OperatorType::kMaxPool: {
        const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get());
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        total += RequiredBufferSizeForShape(output_array.shape()) *
                 maxpool.kheight * maxpool.kwidth;
        break;
      }
      case OperatorType::kAveragePool: {
        const auto& avgpool =
            *static_cast<const AveragePoolOperator*>(op.get());
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        total += RequiredBufferSizeForShape(output_array.shape()) *
                 avgpool.kheight * avgpool.kwidth;
        break;
      }
      case OperatorType::kL2Pool: {
        const auto* maxpool = static_cast<const MaxPoolOperator*>(op.get());
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // The sum of squares requires (kheight*kwidth) multiply-adds,
        // and then there is the sqrt which we ballpark at 32 ops.
        const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
        total +=
            RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
        break;
      }
      case OperatorType::kL2Normalization: {
        const auto& output_array = model.GetArray(op->outputs[0]);
        if (!output_array.has_shape()) {
          return false;
        }
        // Computing the squared L2 norm is N multiply-adds so 2N ops,
        // then the single inverse-sqrt is negligible, then we multiply each
        // value by the resulting multiplier, so an extra N ops. Total 3N ops.
        total += 3 * RequiredBufferSizeForShape(output_array.shape());
        break;
      }
      default:
        break;
    int64 num_ops;
    if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
      return false;
    }
    total += num_ops;
  }
  *result = total;
  return true;
}

string FormattedNumber(int64 x) {
  const int64 million = 1000000;
  const int64 billion = 1000000000;
  if (x < 10000) {
    return toco::port::StringF("%d ", x);
  } else if (x < billion) {
    return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
  } else {
    return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
  }
}
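Expected outputs of FormattedNumber for a few sample values (worked examples, not part of the diff):

  FormattedNumber(1234)        == "1234 "
  FormattedNumber(2500000)     == "2.500 M"
  FormattedNumber(3000000000)  == "3.000 G"

Note that anything in [10000, 1e9) is scaled by a million, so for example 50000 renders as "0.050 M".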

void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
                     std::vector<int>* shuffle) {
  CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
@@ -267,7 +267,10 @@ void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
string CreateInt32Array(Model* model, const string& param_name,
                        const std::vector<int>& value);

bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64* result);
bool EstimateArithmeticOpsCount(const Model& model, int64* result);
string FormattedNumber(int64 x);

int AxesCount(AxesOrder axes_order);