Multiple enhancements to graphviz output:

- nodes organized by subgraphs inferred from array names
- subgraphs and ops include math op count estimates
- better organized labels on all nodes
- operators have numbered inputs and outputs when two or more exist
- specific node shapes for inputs/outputs
- node color scheme unchanged

PiperOrigin-RevId: 233446311
This commit is contained in:
A. Unique TensorFlower 2019-02-11 11:34:45 -08:00 committed by TensorFlower Gardener
parent 734e10d80b
commit 08703e1aad
6 changed files with 761 additions and 333 deletions

View File

@ -378,6 +378,7 @@ cc_library(
":types_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/lite/kernels/internal:types",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
"@protobuf_archive//:protobuf_headers",

View File

@ -15,17 +15,21 @@ limitations under the License.
#include "tensorflow/lite/toco/dump_graphviz.h"
#include <cmath>
#include <functional>
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "re2/re2.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
using toco::port::AppendF;
using toco::port::StringF;
@ -33,72 +37,158 @@ using toco::port::StringF;
namespace toco {
namespace {
// Graphviz format strings used to emit the dump. Each is filled in with
// StringF/AppendF ("%s"/"%d"/"%f" placeholders).
//
// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Subgraph ("cluster") opener: subgraph name, background color, HTML label.
// Note: tooltip's are only supported on SVGs in Chrome.
constexpr char kSubgraphFmt[] =
    R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
// Array node: id, HTML label, tooltip, shape, fill color, font color.
constexpr char kArrayNodeFmt[] =
    R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Operator node: id, HTML label, fill color, font color.
constexpr char kOpNodeFmt[] =
    R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Edge from an array (with optional compass point) into an operator's
// numbered input port "i%d".
constexpr char kInputEdgeFmt[] =
    R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
// Edge from an operator's numbered output port "o%d" to an array.
constexpr char kOutputEdgeFmt[] =
    R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
// RNN back-edge; constraint=false keeps it from influencing the layout.
constexpr char kRNNBackEdgeFmt[] =
    R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
constexpr char kUnicodeMult[] = "\u00D7";        // multiplication sign
constexpr char kUnicodeEllipsis[] = " \u2026 ";  // horizontal ellipsis
// Small RGB color value used to colorize graphviz nodes.
//
// Fix: the rendered block contained two consecutive return statements in
// TextColorString() (the second unreachable) and duplicated/stale comment
// lines; this keeps the "#"-prefixed serialization, which matches the
// "%s"-style color slots in the kArrayNodeFmt/kOpNodeFmt templates above.
class Color {
 public:
  Color() {}
  Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
  // Unpacks a 0x00RRGGBB word into its three channels.
  explicit Color(uint32 word)
      : r_((word & 0x00FF0000) >> 16),
        g_((word & 0x0000FF00) >> 8),
        b_((word & 0x000000FF) >> 0) {}

  // Returns the string serialization of this color in graphviz format,
  // for use as 'fillcolor' in boxes. Kept (without "#") for callers that
  // supply their own "#" in the format string.
  string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); }
  // "#RRGGBB" serialization, directly usable as a graphviz color attribute.
  string AsHexString() const { return StringF("#%.2X%.2X%.2X", r_, g_, b_); }
  // Returns the serialization in graphviz format of a suitable color to use
  // as 'fontcolor' in the same boxes. It is black or white, whichever offers
  // the better contrast from AsHexString().
  string TextColorString() const {
    // https://en.wikipedia.org/wiki/Relative_luminance
    const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
    const uint8 l = luminance > 128.f ? 0 : 255;
    return StringF("#%.2X%.2X%.2X", l, l, l);
  }

 private:
  uint8 r_ = 0, g_ = 0, b_ = 0;
};
// Display properties computed for a single graph node.
struct NodeProperties {
  // The text to display inside the box for this node.
  string label;
  // The color to use for this node; will be used as 'fillcolor'
  // for its box. See Color::FillColorString. A suitable, different
  // color will be chosen for the 'fontcolor' for the inside text
  // label, see Color::TextColorString.
  Color color;
  // log2 of the node's buffer element count; used to scale edge
  // thickness/weight. 0 when the shape is unknown or empty.
  float log2_buffer_size;
};
// Returns a unique color for a name.
//
// This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
// the string to a uint_32, then twiddles some bits to get a light and subtle
// color. This seems to be a good heuristic for keeping enough of the name to
// hash to a unique color while still revealing structure through naming
// similarities.
//
// The regular expression "_\d+" matches any underscore followed by numbers,
// which we strip out. Examples:
//
//   "Conv" -> "Conv"
//   "Conv_2" -> "Conv"
//   "Conv_72" -> "Conv"
//   "Pad_1_bias" -> "Pad_bias"
//   "Conv_abc" -> "Conv_abc"
Color HashStringToColor(string s) {
  RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
  uint32 color_word = std::hash<std::string>{}(s);
  // Force the high bits of every channel on, keeping the color light.
  color_word |= 0x00E0E0E0;
  return Color(color_word);
}
Color GetColorForArray(const Model& model, const string& array_name) {
void GetArrayColorAndShape(const Model& model, const string& array_name,
Color* color, string* shape) {
// All colors in this file are from:
// https://material.io/guidelines/style/color.html
// Arrays involved in RNN back-edges have a different color
for (const auto& rnn_state : model.flags.rnn_states()) {
// RNN state, fed by a back-edge. Bold color.
if (array_name == rnn_state.state_array()) {
return Color(0x0F, 0x9D, 0x58);
*color = Color(0x0F, 0x9D, 0x58);
*shape = "invhouse";
return;
}
// RNN back-edge source, feeding a RNN state.
// Light tone of the same color as RNN states.
if (array_name == rnn_state.back_edge_source_array()) {
return Color(0xB7, 0xE1, 0xCD);
*color = Color(0xB7, 0xE1, 0xCD);
*shape = "house";
return;
}
}
// Constant parameter arrays have their own bold color
if (model.GetArray(array_name).buffer) {
return Color(0x42, 0x85, 0xF4);
*color = Color(0x42, 0x85, 0xF4);
*shape = "cylinder";
return;
}
// Remaining arrays are activations.
// We use gray colors for them because they are the majority
// of arrays so we want to highlight other arrays instead of them.
// First, we use a bolder gray for input/output arrays:
if (IsInputArray(model, array_name)) {
return Color(0x9E, 0x9E, 0x9E);
*color = Color(0x9E, 0x9E, 0x9E);
*shape = "invhouse";
return;
}
if (IsOutputArray(model, array_name)) {
return Color(0x9E, 0x9E, 0x9E);
*color = Color(0x9E, 0x9E, 0x9E);
*shape = "house";
return;
}
// Remaining arrays are intermediate activation arrays.
// Lighter tone of the same grey as for input/output arrays:
// We want these to be very discrete.
return Color(0xF5, 0xF5, 0xF5);
*color = Color(0xF5, 0xF5, 0xF5);
*shape = "box";
}
// Returns the graphviz "compass point" at which edges should attach to the
// node for 'array_name'. For most arrays we don't care (empty string), but
// inputs and outputs look better connected at the tip of the "house" /
// "invhouse" shapes, so they get ":s" and ":n" respectively. RNN state
// arrays behave like inputs; RNN back-edge sources behave like outputs.
string GetArrayCompassPt(const Model& model, const string& array_name) {
  for (const auto& state : model.flags.rnn_states()) {
    if (array_name == state.state_array()) {
      return ":s";  // essentially an input
    }
    if (array_name == state.back_edge_source_array()) {
      return ":n";  // essentially an output
    }
  }
  if (IsInputArray(model, array_name)) {
    return ":s";
  }
  return IsOutputArray(model, array_name) ? ":n" : "";
}
void AppendArrayVal(string* string, Array const& array, int index) {
@ -141,239 +231,550 @@ void AppendArrayVal(string* string, Array const& array, int index) {
}
}
NodeProperties GetPropertiesForArray(const Model& model,
const string& array_name) {
NodeProperties node_properties;
node_properties.color = GetColorForArray(model, array_name);
node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}});
node_properties.log2_buffer_size = 0.0f;
typedef std::map<string, string> Attributes;
// Append array shape to the label.
auto& array = model.GetArray(array_name);
AppendF(&node_properties.label, "\\nType: %s",
ArrayDataTypeName(array.data_type));
if (array.has_shape()) {
auto& array_shape = array.shape();
node_properties.label += "\\n[";
for (int id = 0; id < array_shape.dimensions_count(); id++) {
if (id == 0) {
AppendF(&node_properties.label, "%d", array_shape.dims(id));
} else {
// 0x00D7 is the unicode multiplication symbol
AppendF(&node_properties.label, "\u00D7%d", array_shape.dims(id));
}
}
node_properties.label += "]";
int buffer_size = 0;
if (IsNonEmpty(array.shape())) {
buffer_size = RequiredBufferSizeForShape(array.shape());
node_properties.log2_buffer_size =
std::log2(static_cast<float>(buffer_size));
}
if (array.buffer) {
const auto& array = model.GetArray(array_name);
if (buffer_size <= 4) {
AppendF(&node_properties.label, " = ");
if (array.shape().dimensions_count() > 0) {
AppendF(&node_properties.label, "{");
}
for (int i = 0; i < buffer_size; i++) {
AppendArrayVal(&node_properties.label, array, i);
if (i + 1 < buffer_size) {
AppendF(&node_properties.label, ", ");
}
}
} else {
AppendF(&node_properties.label, "\\n = ");
if (array.shape().dimensions_count() > 0) {
AppendF(&node_properties.label, "{");
}
AppendArrayVal(&node_properties.label, array, 0);
AppendF(&node_properties.label, ", ");
AppendArrayVal(&node_properties.label, array, 1);
// 0x2026 is the unicode ellipsis symbol
AppendF(&node_properties.label, " \u2026 ");
AppendArrayVal(&node_properties.label, array, buffer_size - 2);
AppendF(&node_properties.label, ", ");
AppendArrayVal(&node_properties.label, array, buffer_size - 1);
}
if (array.shape().dimensions_count() > 0) {
AppendF(&node_properties.label, "}");
}
}
string AttributesToHtml(Attributes attributes) {
string html;
for (const auto& attr : attributes) {
html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
html += attr.first;
html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
html += attr.second;
html += "</TD></TR>";
}
if (array.minmax) {
AppendF(&node_properties.label, "\\nMinMax: [%.7g, %.7g]",
array.minmax->min, array.minmax->max);
}
if (array.quantization_params) {
AppendF(&node_properties.label, "\\nQuantization: %7g * (x - %d)",
array.quantization_params->scale,
array.quantization_params->zero_point);
}
if (array.alloc) {
AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)",
array.alloc->start, array.alloc->end);
}
return node_properties;
return html;
}
NodeProperties GetPropertiesForOperator(const Operator& op) {
NodeProperties node_properties;
if (op.type == OperatorType::kUnsupported) {
node_properties.label =
static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
} else {
node_properties.label =
string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
// Builds the HTML-like graphviz label for the array node 'array_id': short
// name, data type and shape, a small sample of constant buffer contents,
// assorted attributes, and the full id in tiny font so it remains searchable
// and copyable in the rendered SVG.
string GetArrayLabel(const Model& model, const string& array_id) {
  string html;
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  html += "<";
  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
  html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";
  auto& array = model.GetArray(array_id);
  if (array.buffer) {
    // "cylinder" shapes require some extra head room.
    html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
  }
  // "Primary" name of array (last non-slash delimited group of characters).
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
  AppendF(&html, R"CODE(%s)CODE",
          std::vector<string>(absl::StrSplit(array_id, '/')).back());
  html += R"CODE(</I></FONT>)CODE";
  html += "</TD></TR>";
  // Array data type and dimensions
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
  // Type
  html += ArrayDataTypeName(array.data_type);
  // Shape, e.g. "[1×224×224×3]", when known.
  if (array.has_shape()) {
    auto& array_shape = array.shape();
    html += "[";
    for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
      AppendF(&html, "%d", array_shape.dims(dim));
      if (dim + 1 < array_shape.dimensions_count()) {
        html += kUnicodeMult;
      }
    }
    html += "]";
  }
  // Small buffer sample (shown inline, on the same row as the type).
  int buffer_size = 0;
  if (array.buffer) {
    buffer_size = RequiredBufferSizeForShape(array.shape());
  }
  if ((buffer_size > 0) && (buffer_size <= 4)) {
    html += " = ";
    // Braces only for non-scalar shapes.
    if (array.shape().dimensions_count() > 0) {
      html += "{";
    }
    for (int i = 0; i < buffer_size; i++) {
      AppendArrayVal(&html, array, i);
      if (i + 1 < buffer_size) {
        html += ", ";
      }
    }
    if (array.shape().dimensions_count() > 0) {
      html += "}";
    }
  }
  html += R"CODE(</B></FONT>)CODE";
  html += "</TD></TR>";
  // Large buffer samples get their own line: first two and last two values,
  // separated by an ellipsis.
  if (buffer_size > 4) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
    AppendArrayVal(&html, array, 0);
    html += ", ";
    AppendArrayVal(&html, array, 1);
    html += kUnicodeEllipsis;
    AppendArrayVal(&html, array, buffer_size - 2);
    html += ", ";
    AppendArrayVal(&html, array, buffer_size - 1);
    html += "}</TD></TR>";
  }
  // Other array properties
  Attributes attrs;
  if (array.minmax) {
    attrs["minmax"] =
        StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
  }
  if (array.quantization_params) {
    attrs["quant"] = StringF("%7g\u00B7(x-%d)",  // Unicode "cdot"
                             array.quantization_params->scale,
                             array.quantization_params->zero_point);
  }
  if (array.alloc) {
    attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
  }
  html += AttributesToHtml(attrs);
  // output array_id in ultra-small font so it can be searched and copied.
  html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
  AppendF(&html, R"CODE("%s")CODE", array_id);
  html += R"CODE(</FONT>)CODE";
  html += "</TD></TR>";
  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";
  return html;
}
// Collects the key/value attributes displayed in the label of operator 'op':
// fused activation function, per-type details (stride/padding, fake-quant
// bits/range), and an estimated math-op count when available.
//
// Fixes: (1) the rendered span interleaved the deleted
// GetPropertiesForOperator() label/color code with the new attrs code,
// including case bodies split across the wrong case labels; (2) the depthwise
// case cast the operator to the unrelated ConvOperator type — the old,
// deleted code shows DepthwiseConvOperator has the same stride/padding
// members, so we cast to the correct type to avoid undefined behavior. The
// old kFullyConnected case only set a color, which now lives in GetOpColor().
Attributes GetOpAttributes(const Model& model, const Operator& op) {
  Attributes attrs;
  switch (op.fused_activation_function) {
    case FusedActivationFunctionType::kRelu:
      attrs["func"] = "ReLU";
      break;
    case FusedActivationFunctionType::kRelu6:
      attrs["func"] = "ReLU6";
      break;
    case FusedActivationFunctionType::kRelu1:
      attrs["func"] = "ReLU1";
      break;
    default:
      break;
  }
  // Additional information for some of the operators.
  switch (op.type) {
    case OperatorType::kConv: {
      const auto& conv_op = static_cast<const ConvOperator&>(op);
      string stride;
      AppendF(&stride, "%d", conv_op.stride_width);
      stride += kUnicodeMult;
      AppendF(&stride, "%d", conv_op.stride_height);
      attrs["stride"] = stride;
      attrs["padding"] =
          (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
      break;
    }
    case OperatorType::kDepthwiseConv: {
      const auto& depthconv_op = static_cast<const DepthwiseConvOperator&>(op);
      string stride;
      AppendF(&stride, "%d", depthconv_op.stride_width);
      stride += kUnicodeMult;
      AppendF(&stride, "%d", depthconv_op.stride_height);
      attrs["stride"] = stride;
      attrs["padding"] =
          (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
      break;
    }
    case OperatorType::kFakeQuant: {
      const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
      attrs["bits"] = StringF("%d", fakequant_op.num_bits);
      if (fakequant_op.minmax) {
        attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
                                 fakequant_op.minmax->max);
      } else {
        attrs["range"] = "[?,?]";
      }
      break;
    }
    default:
      break;
  }
  // Estimated math-op count, when computable and nonzero.
  int64 math_ops_count;
  if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
      (math_ops_count != 0)) {
    attrs["math"] = FormattedNumber(math_ops_count) + "ops";
  }
  return attrs;
}
} // namespace
void DumpGraphviz(const Model& model, string* output_file_contents) {
AppendF(output_file_contents, "digraph Computegraph {\n");
// 'nslimit' is a graphviz (dot) paramater that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
AppendF(output_file_contents, "\t nslimit=125;\n");
constexpr char kNodeFormat[] =
"\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", "
"fontcolor = \"#%sDD\"];\n";
constexpr char kEdgeFormat[] =
"\t \"%s\" -> \"%s\" [penwidth=%f, weight=%f];\n";
constexpr char kRNNBackEdgeFormat[] =
"\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n";
for (const auto& array_kv : model.GetArrayMap()) {
// Add node for array.
const string& array_name = array_kv.first;
const auto& array_properties = GetPropertiesForArray(model, array_name);
AppendF(output_file_contents, kNodeFormat, array_name,
array_properties.label, "octagon",
array_properties.color.FillColorString().c_str(),
array_properties.color.TextColorString().c_str());
// Chooses the fill color for an operator node. A handful of "heavyweight"
// op types are highlighted with a bolder red than the default operator red.
Color GetOpColor(const Operator& op) {
  switch (op.type) {
    case OperatorType::kDepthwiseConv:
    case OperatorType::kConv:
    case OperatorType::kFullyConnected:
    case OperatorType::kFakeQuant:
      // Give some ops a bolder red
      return Color(0xC5, 0x39, 0x29);
    default:
      return Color(0xDB, 0x44, 0x37);
  }
}
// Builds the HTML-like label for operator 'op': a row of input ports, the
// operator name, its attribute table, and a row of output ports. Port names
// "i<N>" / "o<N>" match the port references in kInputEdgeFmt/kOutputEdgeFmt;
// ports are only numbered visibly when there are two or more of them.
string GetOpLabel(const Model& model, const Operator& op) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";
  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
  // Input Ports
  if (!op.inputs.empty()) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
    // Distribute evenly using a sub-table
    html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
    html += R"CODE(<TR>)CODE";
    for (int i = 0; i < op.inputs.size(); i++) {
      html += R"CODE(<TD PORT=")CODE";
      AppendF(&html, "i%d", i);
      html += R"CODE(">)CODE";
      if (op.inputs.size() > 1) {
        // Only number inputs when op has two or more inputs
        AppendF(&html, "%d", i);
      }
      html += "</TD>";
    }
    html += "</TR>";
    html += R"CODE(</TABLE></TD></TR>)CODE";
  }
  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
  if (op.type == OperatorType::kUnsupported) {
    // Unsupported ops are shown under their original TensorFlow op name.
    html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
  } else {
    // Drop the "TensorFlow" prefix from supported op names for brevity.
    html += string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
  }
  html += R"CODE(</B></FONT>)CODE";
  html += "</TD></TR>";
  // Attributes
  Attributes attrs = GetOpAttributes(model, op);
  html += AttributesToHtml(attrs);
  // Output Ports
  if (!op.outputs.empty()) {
    html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
    // Distribute evenly using a sub-table
    html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
    html += R"CODE(<TR>)CODE";
    for (int i = 0; i < op.outputs.size(); i++) {
      html += R"CODE(<TD PORT=")CODE";
      AppendF(&html, "o%d", i);
      html += R"CODE(">)CODE";
      if (op.outputs.size() > 1) {
        // Only number outputs when op has two or more outputs
        AppendF(&html, "%d", i);
      }
      html += "</TD>";
    }
    html += "</TR>";
    html += R"CODE(</TABLE></TD></TR>)CODE";
  }
  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";
  return html;
}
// Returns log2 of the element count of 'array_id's buffer, used to scale
// edge thickness and layout weight. Returns 0.0f when the array's shape is
// unknown or empty.
float GetLog2BufferSize(const Model& model, const string& array_id) {
  const auto& array = model.GetArray(array_id);
  if (!array.has_shape() || !IsNonEmpty(array.shape())) {
    return 0.0f;
  }
  const int buffer_size = RequiredBufferSizeForShape(array.shape());
  return std::log2(static_cast<float>(buffer_size));
}
// Returns a stable graphviz node id (e.g. "op00042") for the operator at
// 'op_index' in the model's operator list.
string GetOpId(int op_index) { return StringF("op%05d", op_index); }
void DumpOperator(const Model& model, string* output_file, int op_index) {
// Dump node for operator.
const Operator& op = *model.operators[op_index];
Color color = GetOpColor(op);
string label = GetOpLabel(model, op);
string op_id = GetOpId(op_index);
AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
color.TextColorString());
}
void DumpOperatorEdges(const Model& model, string* output_file, int op_index) {
// Inputs
const Operator& op = *model.operators[op_index];
string op_id = GetOpId(op_index);
for (int i = 0; i < op.inputs.size(); i++) {
const auto& input = op.inputs[i];
if (!model.HasArray(input)) {
// Connected arrays should _always_ exist. Except, perhaps, during
// development.
continue;
}
float log2_buffer_size = GetLog2BufferSize(model, input);
// Draw lines that transport more data thicker (Otherwise, where would the
// data fit? right?).
float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
// Keep edges that transport more data shorter than those with less.
float weight = std::max(1.0f, log2_buffer_size);
if (!IsInputArray(model, input) &&
GetOpWithOutput(model, input) == nullptr) {
// Give the main line of data flow a straighter path by penalizing edges
// to standalone buffers. Weights are generally very large buffers that
// would otherwise skew the layout.
weight = 1.0f;
}
string compass_pt = GetArrayCompassPt(model, input);
AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
weight);
}
// Outputs
for (int i = 0; i < op.outputs.size(); i++) {
const auto& output = op.outputs[i];
if (!model.HasArray(output)) {
continue;
}
float log2_buffer_size = GetLog2BufferSize(model, output);
// See comments above regarding weight and line_width calculations.
float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
float weight = std::max(1.0f, log2_buffer_size);
if (!IsArrayConsumed(model, output)) {
weight = 1.0f;
}
string compass_pt = GetArrayCompassPt(model, output);
AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
line_width, weight);
}
}
struct Node {
Node() : math_ops(0) {}
// Name used as a key in the model's array map
string array_id;
// Estimated number of math ops incurred by this node (the sum of the op
// with this array as 1st output, plus all children nodes).
int64 math_ops;
// A map of child nodes keyed by name.
std::map<const string, std::unique_ptr<Node>> children;
};
// Builds the HTML-like label for the subgraph named 'subgraph': the name in
// large italics plus the subtree's aggregate math-op estimate from 'node'.
string GetSubgraphLabel(Node const& node, const string& subgraph) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";
  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
  html += subgraph;
  html += R"CODE(</I></FONT>)CODE";
  html += "</TD></TR>";
  // Attributes (only the aggregated math-op estimate, when nonzero).
  Attributes attrs;
  if (node.math_ops > 0) {
    attrs["math"] = FormattedNumber(node.math_ops) + "ops";
  }
  html += AttributesToHtml(attrs);
  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";
  return html;
}
// Opens a graphviz "cluster" subgraph for 'node', with a background color
// derived from a hash of its name. The matching closing brace is emitted by
// the caller (DumpNode).
void DumpSubgraphHeader(string* output_file, Node const& node,
                        const string& node_name) {
  const Color bg_color = HashStringToColor(node_name);
  const string subgraph_label = GetSubgraphLabel(node, node_name);
  AppendF(output_file, kSubgraphFmt, node_name, bg_color.AsHexString(),
          subgraph_label);
}
// Emits the graphviz node statement for array 'array_id', then the node
// statements for every operator whose first output is this array — an op is
// placed in the same subgraph as its first output, so it must be dumped while
// that subgraph is open.
//
// Fix: the rendered span interleaved this new function with the body of the
// deleted monolithic DumpGraphviz() (old kNodeFormat/kEdgeFormat loops over
// GetPropertiesForArray/GetPropertiesForOperator, writing to the out-of-scope
// 'output_file_contents'); only the consolidated function is kept.
void DumpArray(const Model& model, string* output_file,
               const string& array_id) {
  Color color;
  string shape;
  GetArrayColorAndShape(model, array_id, &color, &shape);
  string label = GetArrayLabel(model, array_id);
  AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
          color.AsHexString(), color.TextColorString());
  // Ops are placed in the same subgraph as their first output.
  for (int op_index = 0; op_index < model.operators.size(); op_index++) {
    const Operator& op = *model.operators[op_index];
    if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
      DumpOperator(model, output_file, op_index);
    }
  }
}
// Recursively dumps 'node' and all of its children. A non-root node is
// wrapped in a graphviz subgraph named after it; each child's array (if any)
// is emitted inside this subgraph before recursing into the child.
void DumpNode(const Model& model, string* output_file, const string& node_name,
              Node const& node) {
  // The root of the tree has an empty name and gets no subgraph of its own.
  bool not_root = !node_name.empty();
  if (not_root) {
    DumpSubgraphHeader(output_file, node, node_name);
  }
  for (const auto& child : node.children) {
    if (!child.second->array_id.empty()) {
      // Dump array if this node possesses one.
      DumpArray(model, output_file, child.second->array_id);
    }
    // Note that it is always possible to have children. Unlike a filesystem,
    // the existence of array "foo/bar" does _not_ prevent other arrays, such
    // as "foo/bar/baz", from being nested beneath it.
    DumpNode(model, output_file, child.first, *child.second);
  }
  if (not_root) {
    // End subgraph
    AppendF(output_file, " }\n");
  }
}
// Returns the estimated math-op count of the (first) operator whose first
// output is 'array_id', or 0 when no such operator exists or the estimate
// is unavailable.
int64 GetArithmeticOpsCount(const Model& model, const string& array_id) {
  for (const auto& op : model.operators) {
    if (op->outputs.empty() || op->outputs[0] != array_id) {
      continue;
    }
    int64 count;
    return EstimateArithmeticOpsCount(model, *op, &count) ? count : 0;
  }
  return 0;
}
// Inserts 'array_id' into the tree rooted at 'node', descending through
// 'prefixes' — the '/'-separated path components in REVERSE order, so the
// next component to consume is at the back. On return, '*math_ops' holds the
// estimated math-op count of the inserted leaf; because the addition happens
// after the recursive call, that count is accumulated into every node along
// the insertion path (leaf first, then each ancestor).
void InsertNode(const Model& model, const string& array_id, Node* node,
                std::vector<string> prefixes, int64* math_ops) {
  if (prefixes.empty()) {
    // Base case: store array in this node. This also assigns '*math_ops'
    // before any read of it further up the recursion.
    node->array_id = array_id;
    *math_ops = GetArithmeticOpsCount(model, array_id);
  } else {
    // Insert into the sub-tree for that prefix.
    string prefix = prefixes.back();
    prefixes.pop_back();
    if (node->children.count(prefix) == 0) {
      // Create a new node if this prefix is unseen.
      node->children[prefix] = absl::make_unique<Node>();
    }
    InsertNode(model, array_id, node->children[prefix].get(), prefixes,
               math_ops);
  }
  // Sum estimated math ops into all nodes.
  node->math_ops += *math_ops;
}
// Splits every array name on '/' and inserts it into 'tree', so arrays that
// share a path prefix land in a common subtree (and thus a common graphviz
// subgraph).
void BuildArrayTree(const Model& model, Node* tree) {
  // Delimit array names by path "/", then place into a tree based on this path.
  for (const auto& array_id : model.GetArrayMap()) {
    std::vector<string> prefixes = absl::StrSplit(array_id.first, '/');
    std::reverse(prefixes.begin(), prefixes.end());
    // Temporary storage for math ops used during recursion. Initialized
    // defensively: InsertNode's base case always assigns it before it is
    // read, but passing an uninitialized int64 by pointer is fragile under
    // refactoring.
    int64 math_ops = 0;
    InsertNode(model, array_id.first, tree, prefixes, &math_ops);
  }
}
// Builds the HTML-like title label for the whole graph: 'graph_name' in very
// large type plus model-wide statistics (array/operator counts, estimated
// math ops, transient-data size and alignment).
string GetGraphLabel(const Model& model, const string& graph_name) {
  // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
  string html;
  html += "<";
  // Begin Table
  html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
  html +=
      R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
  // Name
  html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
  html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
  html += graph_name;
  html += R"CODE(</I></B></FONT>)CODE";
  html += "</TD></TR>";
  // Attributes
  Attributes attrs;
  attrs["arrays"] = StringF("%d", model.GetArrayMap().size());
  if (!model.optional_arrays.empty()) {
    attrs["optional arrays"] = StringF("%d", model.optional_arrays.size());
  }
  attrs["operators"] = StringF("%d", model.operators.size());
  int64 ops_count;
  if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
    attrs["math"] = FormattedNumber(ops_count) + "ops";
  }
  if (model.transient_data_size > 0) {
    attrs["transient data size"] =
        StringF("%d KiB", model.transient_data_size / 1024);
  }
  if (model.transient_data_alignment > 0) {
    attrs["transient data alignment"] =
        StringF("%d bytes", model.transient_data_alignment);
  }
  html += AttributesToHtml(attrs);
  // End Table and HTML-like label
  html += R"CODE(</TABLE></FONT>)CODE";
  html += ">";
  return html;
}
} // namespace
// Writes the full graphviz (dot) rendering of 'model' into '*output_file',
// titled 'graph_name'. Arrays are grouped into nested subgraphs by their
// '/'-delimited name prefixes; operator edges and RNN back-edges are emitted
// at top level.
//
// Fix: the rendered body interleaved leftovers of the old implementation —
// AppendF calls on the renamed-away 'output_file_contents' parameter and the
// deleted 'kRNNBackEdgeFormat' constant — which could not compile; only the
// new statements are kept.
void DumpGraphviz(const Model& model, string* output_file,
                  const string& graph_name) {
  // Start graphviz format
  AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
  // Organize arrays into a tree for subgraphing
  Node tree;
  BuildArrayTree(model, &tree);
  DumpNode(model, output_file, "", tree);
  // Dump edges outside all subgraphs (otherwise the referred-to nodes are
  // implicitly included in that subgraph).
  for (int op_index = 0; op_index < model.operators.size(); op_index++) {
    DumpOperatorEdges(model, output_file, op_index);
  }
  // Dump RNN Backedges
  for (const auto& rnn_state : model.flags.rnn_states()) {
    AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
            rnn_state.state_array());
  }
  // End graphviz format
  AppendF(output_file, "}\n");
}
} // namespace toco

View File

@ -21,7 +21,8 @@ limitations under the License.
namespace toco {
void DumpGraphviz(const Model& model, string* output_file_contents);
void DumpGraphviz(const Model& model, string* output_file_contents,
const string& graph_name);
} // namespace toco

View File

@ -454,7 +454,7 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
return status;
} break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
DumpGraphviz(model, output_file_contents, "Computation Graph");
break;
default:
LOG(FATAL) << "Unhandled output_format='"

View File

@ -66,29 +66,29 @@ string LogName(const Operator& op) {
string ArrayDataTypeName(ArrayDataType data_type) {
switch (data_type) {
case ArrayDataType::kFloat:
return "Float";
return "float";
case ArrayDataType::kInt8:
return "Int8";
return "int8";
case ArrayDataType::kUint8:
return "Uint8";
return "uint8";
case ArrayDataType::kInt16:
return "Int16";
return "int16";
case ArrayDataType::kUint16:
return "Uint16";
return "uint16";
case ArrayDataType::kInt32:
return "Int32";
return "int32";
case ArrayDataType::kUint32:
return "Uint32";
return "uint32";
case ArrayDataType::kInt64:
return "Int64";
return "int64";
case ArrayDataType::kUint64:
return "Uint64";
return "uint64";
case ArrayDataType::kString:
return "String";
return "string";
case ArrayDataType::kBool:
return "Bool";
return "bool";
case ArrayDataType::kComplex64:
return "Complex64";
return "complex64";
case ArrayDataType::kNone:
return "None";
default:
@ -538,7 +538,8 @@ void DumpGraphvizVideoFrame(const Model& model) {
static int dump_id = 0;
static std::unordered_set<std::size_t> dump_hashes;
string graphviz_dump;
DumpGraphviz(model, &graphviz_dump);
DumpGraphviz(model, &graphviz_dump,
toco::port::StringF("VIDEO frame:%05d", dump_id));
std::size_t hash = std::hash<string>{}(graphviz_dump);
if (!dump_hashes.count(hash)) {
LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
@ -561,7 +562,7 @@ void LogDump(int log_level, const string& message, const Model& model) {
if (!dump_options.dump_graphviz.empty()) {
string graphviz_dump;
DumpGraphviz(model, &graphviz_dump);
DumpGraphviz(model, &graphviz_dump, message);
const auto result = port::file::SetContents(
port::file::JoinPath(
dump_options.dump_graphviz,
@ -1863,119 +1864,140 @@ string CreateInt32Array(Model* model, const string& param_name,
return param_array_name;
}
// Estimates the number of arithmetic ("math") ops executed by a single
// operator, writing the estimate to |result|. Multiply-accumulates count as
// two ops. Returns false if the estimate cannot be computed because a
// required array shape is not yet known; returns true otherwise, with
// *result set to 0 for operator types that have no cost model.
bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
                                int64* result) {
  switch (op.type) {
    case OperatorType::kFullyConnected:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      const auto& weights_array = model.GetArray(op.inputs[1]);
      if (!output_array.has_shape() || !weights_array.has_shape()) {
        return false;
      }
      // "cols" is the number of output positions: the product of every
      // output dimension except the innermost one.
      int64 cols = 1;
      for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
        cols *= output_array.shape().dims(i);
      }
      // Each output position consumes the whole weights buffer once; each
      // weight contributes one multiply plus one add (2 ops).
      const int64 cost_per_col =
          2 * RequiredBufferSizeForShape(weights_array.shape());
      *result = cost_per_col * cols;
      if (op.inputs.size() > 2) {
        // There is a bias vector. One more op per output value.
        *result += RequiredBufferSizeForShape(output_array.shape());
      }
      break;
    }
    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul: {
      // Elementwise binary ops: one op per output value.
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      *result = RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kAddN: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // AddN cost is roughly the same cost as N-1 Adds.
      const int64 num_adds = op.inputs.size() - 1;
      *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kLogistic:
    case OperatorType::kSoftmax:
    case OperatorType::kLogSoftmax:
    case OperatorType::kTanh: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // As a very rough ballpark, the cost of evaluating a math function
      // such as tanh or logistic is about 32 multiplications, and about as
      // many additions/subtractions. (Just a power-of-two order-of-magnitude
      // from looking at actual implementations that we use in runtime/ code).
      *result = 64 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    case OperatorType::kMaxPool: {
      const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One comparison per kernel element, per output value.
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                maxpool.kheight * maxpool.kwidth;
      break;
    }
    case OperatorType::kAveragePool: {
      const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // One addition per kernel element, per output value (the final divide
      // per output value is not counted here).
      *result = RequiredBufferSizeForShape(output_array.shape()) *
                avgpool.kheight * avgpool.kwidth;
      break;
    }
    case OperatorType::kL2Pool: {
      // NOTE(review): this casts an L2Pool op to MaxPoolOperator just to read
      // kheight/kwidth -- presumably the two operator structs share that
      // layout, but worth confirming against the operator definitions.
      const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // The sum of squares requires (kheight*kwidth) multiply-adds,
      // and then there is the sqrt which we ballpark at 32 ops.
      const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
      *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
      break;
    }
    case OperatorType::kL2Normalization: {
      const auto& output_array = model.GetArray(op.outputs[0]);
      if (!output_array.has_shape()) {
        return false;
      }
      // Computing the squared L2 norm is N multiply-adds so 2N ops,
      // then the single inverse-sqrt is negligible, then we multiply each
      // value by the resulting multiplier, so an extra N ops. count 3N ops.
      *result = 3 * RequiredBufferSizeForShape(output_array.shape());
      break;
    }
    default:
      // Unknown or zero-cost operator type: report zero ops.
      *result = 0;
      break;
  }
  return true;
}
bool EstimateArithmeticOpsCount(const Model& model, int64* result) {
int64 total = 0;
for (const auto& op : model.operators) {
switch (op->type) {
case OperatorType::kFullyConnected:
case OperatorType::kConv:
case OperatorType::kDepthwiseConv: {
const auto& output_array = model.GetArray(op->outputs[0]);
const auto& weights_array = model.GetArray(op->inputs[1]);
if (!output_array.has_shape() || !weights_array.has_shape()) {
return false;
}
int cols = 1;
for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
cols *= output_array.shape().dims(i);
}
const int64 cost_per_col =
2 * RequiredBufferSizeForShape(weights_array.shape());
total += cost_per_col * cols;
if (op->inputs.size() > 2) {
// There is a bias vector. One more op per output value.
total += RequiredBufferSizeForShape(output_array.shape());
}
break;
}
case OperatorType::kAdd:
case OperatorType::kSub:
case OperatorType::kMul: {
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
total += RequiredBufferSizeForShape(output_array.shape());
break;
}
case OperatorType::kAddN: {
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
// AddN cost is roughly the same cost as N-1 Adds.
const int num_adds = op->inputs.size() - 1;
total += num_adds * RequiredBufferSizeForShape(output_array.shape());
break;
}
case OperatorType::kLogistic:
case OperatorType::kSoftmax:
case OperatorType::kLogSoftmax:
case OperatorType::kTanh: {
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
// As a very rough ballpark, the cost of evaluating a math function
// such as tanh or logistic is about 32 multiplications, and about as
// many additions/subtractions. (Just a power-of-two order-of-magnitude
// from looking at actual implementations that we use in runtime/ code).
total += 64 * RequiredBufferSizeForShape(output_array.shape());
break;
}
case OperatorType::kMaxPool: {
const auto& maxpool = *static_cast<const MaxPoolOperator*>(op.get());
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
total += RequiredBufferSizeForShape(output_array.shape()) *
maxpool.kheight * maxpool.kwidth;
break;
}
case OperatorType::kAveragePool: {
const auto& avgpool =
*static_cast<const AveragePoolOperator*>(op.get());
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
total += RequiredBufferSizeForShape(output_array.shape()) *
avgpool.kheight * avgpool.kwidth;
break;
}
case OperatorType::kL2Pool: {
const auto* maxpool = static_cast<const MaxPoolOperator*>(op.get());
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
// The sum of squares requires (kheight*kwidth) multiply-adds,
// and then there is the sqrt which we ballpark at 32 ops.
const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
total +=
RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
break;
}
case OperatorType::kL2Normalization: {
const auto& output_array = model.GetArray(op->outputs[0]);
if (!output_array.has_shape()) {
return false;
}
// Computing the squared L2 norm is N multiply-adds so 2N ops,
// then the single inverse-sqrt is negligible, then we multiply each
// value by the resulting multiplier, so an extra N ops. Total 3N ops.
total += 3 * RequiredBufferSizeForShape(output_array.shape());
break;
}
default:
break;
int64 num_ops;
if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
return false;
}
total += num_ops;
}
*result = total;
return true;
}
// Renders |x| as a short human-readable magnitude string: small values are
// printed verbatim (with a trailing space), larger values are scaled to
// millions ("M") or billions ("G") with three decimal places.
string FormattedNumber(int64 x) {
  constexpr int64 kMillion = 1000000;
  constexpr int64 kBillion = 1000000000;
  if (x < 10000) {
    return toco::port::StringF("%d ", x);
  }
  if (x < kBillion) {
    return toco::port::StringF("%.3f M", static_cast<double>(x) / kMillion);
  }
  return toco::port::StringF("%.3f G", static_cast<double>(x) / kBillion);
}
void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
std::vector<int>* shuffle) {
CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));

View File

@ -267,7 +267,10 @@ void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
string CreateInt32Array(Model* model, const string& param_name,
const std::vector<int>& value);
bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
int64* result);
bool EstimateArithmeticOpsCount(const Model& model, int64* result);
string FormattedNumber(int64 x);
int AxesCount(AxesOrder axes_order);