Optimize eight-bit graphs by removing RequantizationRanges
Change: 144145086
This commit is contained in:
parent
5d9a05a8b0
commit
b42ba8aace
@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -60,7 +60,9 @@ cc_library(
|
||||
"fold_batch_norms.cc",
|
||||
"fold_constants_lib.cc",
|
||||
"fold_old_batch_norms.cc",
|
||||
"freeze_requantization_ranges.cc",
|
||||
"fuse_convolutions.cc",
|
||||
"insert_logging.cc",
|
||||
"obsfucate_names.cc",
|
||||
"quantize_nodes.cc",
|
||||
"quantize_weights.cc",
|
||||
@ -101,7 +103,9 @@ tf_cc_test(
|
||||
"fold_batch_norms_test.cc",
|
||||
"fold_constants_test.cc",
|
||||
"fold_old_batch_norms_test.cc",
|
||||
"freeze_requantization_ranges_test.cc",
|
||||
"fuse_convolutions_test.cc",
|
||||
"insert_logging_test.cc",
|
||||
"obsfucate_names_test.cc",
|
||||
"quantize_nodes_test.cc",
|
||||
"quantize_weights_test.cc",
|
||||
|
@ -16,7 +16,9 @@
|
||||
* [fold_batch_norms](#fold_batch_norms)
|
||||
* [fold_constants](#fold_constants)
|
||||
* [fold_old_batch_norms](#fold_old_batch_norms)
|
||||
* [freeze_requantization_ranges](#freeze_requantization_ranges)
|
||||
* [fuse_convolutions](#fuse_convolutions)
|
||||
* [insert_logging](#insert_logging)
|
||||
* [merge_duplicate_nodes](#merge_duplicate_nodes)
|
||||
* [obsfucate_names](#obsfucate_names)
|
||||
* [quantize_nodes](#quantize_nodes)
|
||||
@ -334,7 +336,7 @@ within the saved model, and sets them to the defined default for that attribute.
|
||||
|
||||
### fold_batch_norms
|
||||
|
||||
Args: None
|
||||
Args: None \
|
||||
Prerequisites: [fold_constants](#fold_constants)
|
||||
|
||||
This transform tries to optimize away the Mul that's introduced after a Conv2D
|
||||
@ -347,7 +349,7 @@ produced by training for the Mul input is collapsed down into a simple constant.
|
||||
|
||||
### fold_constants
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
Looks for any sub-graphs within the model that always evaluate to constant
|
||||
@ -359,7 +361,7 @@ to continue on past transient errors, since this is just an optimization phase.
|
||||
|
||||
### fold_old_batch_norms
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
In the early days of TensorFlow, batch normalization was implemented using a
|
||||
@ -370,21 +372,143 @@ have a graph that uses the older-style, this transform will recognize and
|
||||
optimize those ops for inference, in the same way that the
|
||||
[fold_batch_norms](#fold_batch_norms) transform does for the new approach.
|
||||
|
||||
### freeze_requantization_ranges
|
||||
|
||||
Args:
|
||||
|
||||
* min_max_log_file: Path to a log file containing ranges for ops.
|
||||
* min_percentile: Percentage cutoff to use to calculate an overall min.
|
||||
Defaults to 5.
|
||||
* max_percentile: Percentage cutoff to use to calculate an overall max.
|
||||
Defaults to 5.
|
||||
|
||||
Quantized operations like convolution or matrix multiplies take their inputs as
|
||||
8-bit, but produce 32-bit results. To do further operations on these, they need
|
||||
to be converted back down to the lower depth. To make the most of those eight
|
||||
bits, you need to scale the thirty-two bits of original data down using a scale
|
||||
that matches the range that's actually being used.
|
||||
|
||||
Because that range information isn't stored in the original graph, the
|
||||
[quantization process](#eight-bit-calculations) inserts RequantizationRange ops
|
||||
before each conversion from 32 to 8 bits. This op looks at the 32-bit output and
|
||||
calculates the current min and max every time it's run.
|
||||
|
||||
This isn't incredibly time-consuming, but it is extra work that's nice to avoid
|
||||
if possible. One way of optimizing that away is replacing those
|
||||
RequantizationRange ops with a pair of Const nodes holding known min/max values,
|
||||
so the scaling down can be done without having to inspect the output every time.
|
||||
|
||||
That's what this transform does. It's usually used in conjunction with a copy of
|
||||
the graph that's had [insert_logging](#insert_logging) run on it to instrument
|
||||
it to record the min/max values to stderr. Why is logging used rather than
|
||||
writing to a normal file? As you'll see later, to get best results you want to
|
||||
collect data from a lot of runs on real data, and for mobile apps especially
|
||||
it's a lot easier to do this by copying log files. As an example, here's how
|
||||
you'd add the logging operations for a quantized version of the Inception v3
|
||||
graph:
|
||||
|
||||
```bash
|
||||
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||
--logtostderr \
|
||||
--in_graph=/tmp/quantized_inception.pb \
|
||||
--out_graph=/tmp/logged_quantized_inception.pb \
|
||||
--inputs=Mul \
|
||||
--outputs=softmax \
|
||||
--transforms='\
|
||||
insert_logging(op=RequantizationRange, show_name=true, message="__requant_min_max:")\
|
||||
'
|
||||
```
|
||||
|
||||
Now, when you run the `/tmp/logged_quantized_inception.pb` graph, it will write
|
||||
out log statements that show the value of the min and max calculated by each
|
||||
RequantizationRange op. Here's an example of running label_image and saving the
|
||||
log:
|
||||
|
||||
```bash
|
||||
bazel build tensorflow/examples/label_image:label_image
|
||||
bazel-bin/tensorflow/examples/label_image/label_image \
|
||||
--image=${HOME}/Downloads/grace_hopper.jpg \
|
||||
--logtostderr \
|
||||
--input_layer=Mul \
|
||||
--output_layer=softmax \
|
||||
--graph=/tmp/logged_quantized_inception.pb \
|
||||
--labels=learning/brain/models/image/inception_v3/imagenet_comp_graph_label_strings.txt \
|
||||
--logtostderr \
|
||||
2>/tmp/min_max_log_small.txt
|
||||
```
|
||||
|
||||
If you look in `/tmp/min_max_log_small.txt`, you'll see a lot of lines like
|
||||
this:
|
||||
|
||||
```
|
||||
I0108 21:45:42.261883 1972 logging_ops.cc:79] ;conv/Conv2D/eightbit/requant_range__print__;__requant_min_max:[-20.887871][22.274715]
|
||||
```
|
||||
|
||||
This is a simple way of serializing the name of the RequantizationRange op and
|
||||
its min/max values every time it's run. It's a file like this that you pass into
|
||||
the transform as the `min_max_log_file` argument. The transform will attempt to
|
||||
extract all of the min/max values associated with ops, ignoring any irrelevant
|
||||
lines in the log, and replace the RequantizationRange ops with two Const nodes
|
||||
containing the found values.
|
||||
|
||||
This isn't the whole story though. The min/max values can vary a lot depending
|
||||
on what the particular inputs to the graph are on any given run, which means
|
||||
picking ranges based on just one run can lead to clipping of values and a loss
|
||||
of accuracy. To get better results, you need to run your network against a range
|
||||
of different inputs. In Inception's case, I often use a thousand different
|
||||
images from the training set. You can then pass the whole concatenated log from
|
||||
all of the runs into the transform, and it will pick ranges based on the
|
||||
aggregate of the values found for each RequantizationRange op.
|
||||
|
||||
To ensure that outliers don't increase the range too much, and so decrease the
|
||||
accuracy by putting too many bits into rare extreme values, the `min_percentile`
|
||||
and `max_percentile` arguments control how the overall min and max are chosen.
|
||||
At their default values of 5, this means that the lowest 5% of the minimum
|
||||
values will be discarded, taking the minimum of the remainder, and the
|
||||
equivalent for the maximum.
|
||||
|
||||
### fuse_convolutions
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g.
|
||||
to scale up in the later stages of an image style transfer model),
|
||||
it can improve memory usage and latency to combine the spatial
|
||||
transformations with the convolution's im2col patch generation. This transform
|
||||
looks out for that particular pattern of ops and replaces them with a fused
|
||||
version that combines the resizing and padding with the convolution.
|
||||
For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g. to
|
||||
scale up in the later stages of an image style transfer model), it can improve
|
||||
memory usage and latency to combine the spatial transformations with the
|
||||
convolution's im2col patch generation. This transform looks out for that
|
||||
particular pattern of ops and replaces them with a fused version that combines
|
||||
the resizing and padding with the convolution.
|
||||
|
||||
### insert_logging
|
||||
|
||||
Args:
|
||||
|
||||
* op: Insert a Print after every occurrence of this op type. Can be repeated
|
||||
to cover multiple types. If not present, all op types will be instrumented.
|
||||
* prefix: Insert a Print after every node whose name starts with this value.
|
||||
Can be repeated to cover multiple nodes. If not present, all node names will
|
||||
be matched.
|
||||
* show_op: If true, the op type will be prepended to all log messages.
|
||||
* show_name: If true, the node's name will be prepended to all log messages.
|
||||
* message: Arbitrary text to log before the values.
|
||||
* first_n: How many times to print before suppressing. Defaults to -1, which
|
||||
means never stop.
|
||||
* summarize: How long numerical results can be before they're truncated.
|
||||
Defaults to 1024.
|
||||
|
||||
The Print operator writes strings to stderr when it's run inside a graph, and
|
||||
prints out the numerical results of the node that it's reading from. This can be
|
||||
very useful when you're debugging and want to follow particular internal values
|
||||
while a graph is running. This transform allows you to insert those ops at
|
||||
particular points in the graph, and customize the message that's displayed. It's
|
||||
also used in conjunction with the
|
||||
[freeze_requantization_ranges](#freeze_requantization_ranges) transform to
|
||||
output information that it needs.
|
||||
|
||||
### merge_duplicate_nodes
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
If there are Const nodes with the same types and contents, or nodes with the
|
||||
@ -396,7 +520,7 @@ duplicates of constants that are used in the quantize/dequantize process).
|
||||
|
||||
### obsfucate_names
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
Replaces all nodes' names with short generated ids, other than the inputs and
|
||||
@ -429,14 +553,14 @@ Prerequisites: [quantize_weights](#quantize_weights)
|
||||
Replaces any calculation nodes with their eight-bit equivalents (if available),
|
||||
and adds in conversion layers to allow remaining float operations to
|
||||
interoperate. This is one of the most complex transforms, and involves multiple
|
||||
passes and a lot of rewriting. It's also still an active area of research,
|
||||
so results may vary depending on the platform and operations you're using in
|
||||
your model. You should run quantize_weights first to ensure your Const ops are
|
||||
in eight-bit form.
|
||||
passes and a lot of rewriting. It's also still an active area of research, so
|
||||
results may vary depending on the platform and operations you're using in your
|
||||
model. You should run quantize_weights first to ensure your Const ops are in
|
||||
eight-bit form.
|
||||
|
||||
### quantize_weights
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
Converts any large (more than 15 element) float Const op into an eight-bit
|
||||
@ -461,7 +585,7 @@ special circumstances though.
|
||||
|
||||
### remove_device
|
||||
|
||||
Args: None
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
All ops can have a hardware device specified. This can be a problem when you're
|
||||
@ -548,7 +672,7 @@ device assigned.
|
||||
|
||||
### sort_by_execution_order
|
||||
|
||||
Args: None\
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
Arranges the nodes in the GraphDef in topological order, so that the inputs of
|
||||
@ -741,8 +865,8 @@ transform:
|
||||
This is looking for QuantizeV2 nodes, with three inputs, the first of which is a
|
||||
Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the
|
||||
third is a Max which does the same. Assuming we know the Dequantize ops are
|
||||
pulling from the same eight-bit buffer, the end result of this sub-graph is
|
||||
a no-op, since it's just turning the eight-bit buffer into float, and then
|
||||
pulling from the same eight-bit buffer, the end result of this sub-graph is a
|
||||
no-op, since it's just turning the eight-bit buffer into float, and then
|
||||
immediately converting it back to eight-bits, so if we look for this pattern and
|
||||
remove it we can optimize the graph without changing the result.
|
||||
|
||||
@ -874,7 +998,7 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
|
||||
parameter:
|
||||
|
||||
```C++
|
||||
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
|
||||
TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("num_steps", 256, &num_steps));
|
||||
```
|
||||
|
||||
If the conversion fails or the parameter occurs more than once the helper
|
||||
|
@ -0,0 +1,213 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
struct MinMaxRecord {
|
||||
string name;
|
||||
float min;
|
||||
float max;
|
||||
};
|
||||
|
||||
// Try to parse a log file containing loosely-structured lines, some of which
|
||||
// are the min/max logs we want.
|
||||
Status ExtractMinMaxRecords(const string& log_file_name,
|
||||
std::vector<MinMaxRecord>* records) {
|
||||
string file_data;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadFileToString(Env::Default(), log_file_name, &file_data));
|
||||
const string print_suffix("__print__");
|
||||
const string requant_prefix("__requant_min_max:");
|
||||
std::vector<string> file_lines = str_util::Split(file_data, '\n');
|
||||
for (const string& file_line : file_lines) {
|
||||
// We expect to find a line with components separated by semicolons, so to
|
||||
// start make sure that the basic structure is in place/
|
||||
StringPiece line(file_line);
|
||||
if (!line.contains(print_suffix + ";" + requant_prefix)) {
|
||||
continue;
|
||||
}
|
||||
std::vector<string> line_parts = str_util::Split(file_line, ';');
|
||||
if (line_parts.size() < 2) {
|
||||
continue;
|
||||
}
|
||||
// Now we want to figure out which components have the name and min max
|
||||
// values by scanning for the prefix we expect.
|
||||
bool min_max_found = false;
|
||||
int min_max_index;
|
||||
for (int i = 1; i < line_parts.size(); ++i) {
|
||||
StringPiece line_part(line_parts[i]);
|
||||
if (line_part.starts_with(requant_prefix)) {
|
||||
min_max_found = true;
|
||||
min_max_index = i;
|
||||
}
|
||||
}
|
||||
if (!min_max_found) {
|
||||
continue;
|
||||
}
|
||||
// Finally we need to break out the values from the strings, and parse them
|
||||
// into a form we can use.
|
||||
string min_max_string = line_parts[min_max_index];
|
||||
std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
|
||||
if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
|
||||
continue;
|
||||
}
|
||||
string min_string = min_max_parts[1];
|
||||
std::vector<string> min_string_parts = str_util::Split(min_string, ']');
|
||||
if (min_string_parts.size() != 2) {
|
||||
continue;
|
||||
}
|
||||
string min_number_string = min_string_parts[0];
|
||||
float min;
|
||||
if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
|
||||
continue;
|
||||
}
|
||||
string max_string = min_max_parts[2];
|
||||
std::vector<string> max_string_parts = str_util::Split(max_string, ']');
|
||||
if (max_string_parts.size() != 2) {
|
||||
continue;
|
||||
}
|
||||
string max_number_string = max_string_parts[0];
|
||||
float max;
|
||||
if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
|
||||
continue;
|
||||
}
|
||||
StringPiece name_string = line_parts[min_max_index - 1];
|
||||
if (!name_string.ends_with(print_suffix)) {
|
||||
continue;
|
||||
}
|
||||
string name =
|
||||
name_string.substr(0, name_string.size() - print_suffix.size())
|
||||
.ToString();
|
||||
records->push_back({name, min, max});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Uses the observed min/max values for requantization captured in a log file to
|
||||
// replace costly RequantizationRange ops with simple Consts.
|
||||
Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
string min_max_log_file;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
|
||||
if (min_max_log_file == "") {
|
||||
return errors::InvalidArgument(
|
||||
"You must pass a file name to min_max_log_file");
|
||||
}
|
||||
float min_percentile;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
|
||||
float max_percentile;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
|
||||
|
||||
std::vector<MinMaxRecord> records;
|
||||
TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
|
||||
if (records.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"No min/max range logs were found in the log file");
|
||||
}
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(input_graph_def, &node_map);
|
||||
bool any_missing_nodes = false;
|
||||
std::map<string, std::vector<MinMaxRecord>> records_by_node;
|
||||
for (const MinMaxRecord& record : records) {
|
||||
records_by_node[record.name].push_back(record);
|
||||
if (!node_map.count(record.name)) {
|
||||
any_missing_nodes = true;
|
||||
LOG(WARNING) << "Node from log not found in graph: " << record.name;
|
||||
}
|
||||
}
|
||||
if (any_missing_nodes) {
|
||||
return errors::InvalidArgument(
|
||||
"Nodes were found in the log file that aren't present in the graph");
|
||||
}
|
||||
|
||||
// Now find out the largest and smallest min/max values for the node.
|
||||
std::map<string, std::pair<float, float>> range_for_nodes;
|
||||
for (const auto& record_info : records_by_node) {
|
||||
const string& name = record_info.first;
|
||||
const std::vector<MinMaxRecord> records = record_info.second;
|
||||
std::vector<float> mins;
|
||||
std::vector<float> maxs;
|
||||
for (const MinMaxRecord& record : records) {
|
||||
mins.push_back(record.min);
|
||||
maxs.push_back(record.max);
|
||||
}
|
||||
std::sort(mins.begin(), mins.end());
|
||||
std::sort(maxs.begin(), maxs.end());
|
||||
int min_index = std::round(mins.size() * (min_percentile / 100.0f));
|
||||
if (min_index < 0) {
|
||||
min_index = 0;
|
||||
}
|
||||
int max_index =
|
||||
std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
|
||||
if (max_index > (maxs.size() - 1)) {
|
||||
max_index = maxs.size() - 1;
|
||||
}
|
||||
const float min = mins[min_index];
|
||||
const float max = maxs[max_index];
|
||||
range_for_nodes[name] = {min, max};
|
||||
}
|
||||
std::map<string, string> inputs_to_rename;
|
||||
GraphDef frozen_graph_def;
|
||||
for (const NodeDef& node : input_graph_def.node()) {
|
||||
if (range_for_nodes.count(node.name())) {
|
||||
if (node.op() != "RequantizationRange") {
|
||||
return errors::InvalidArgument(
|
||||
"Node is expected to be a RequantizationRange op: ", node.name(),
|
||||
", but is: ", node.op());
|
||||
}
|
||||
const float min_value = range_for_nodes.at(node.name()).first;
|
||||
NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
|
||||
min_node->set_op("Const");
|
||||
min_node->set_name(node.name() + "/frozen_min");
|
||||
SetNodeAttr("dtype", DT_FLOAT, min_node);
|
||||
Tensor min_tensor(DT_FLOAT, {});
|
||||
min_tensor.flat<float>()(0) = min_value;
|
||||
SetNodeTensorAttr<float>("value", min_tensor, min_node);
|
||||
inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
|
||||
|
||||
const float max_value = range_for_nodes.at(node.name()).second;
|
||||
NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
|
||||
max_node->set_op("Const");
|
||||
max_node->set_name(node.name() + "/frozen_max");
|
||||
SetNodeAttr("dtype", DT_FLOAT, max_node);
|
||||
Tensor max_tensor(DT_FLOAT, {});
|
||||
max_tensor.flat<float>()(0) = max_value;
|
||||
SetNodeTensorAttr<float>("value", max_tensor, max_node);
|
||||
inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
|
||||
} else {
|
||||
NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
|
||||
new_node->CopyFrom(node);
|
||||
}
|
||||
}
|
||||
RenameNodeInputs(frozen_graph_def, inputs_to_rename,
|
||||
std::unordered_set<string>(), output_graph_def);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
|
||||
FreezeRequantizationRanges);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -0,0 +1,200 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
struct MinMaxRecord {
|
||||
string name;
|
||||
float min;
|
||||
float max;
|
||||
};
|
||||
Status ExtractMinMaxRecords(const string& log_file_name,
|
||||
std::vector<MinMaxRecord>* records);
|
||||
|
||||
class FreezeRequantizationRangesTest : public ::testing::Test {
|
||||
protected:
|
||||
void TestFreezeRequantizationRanges() {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
|
||||
test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
|
||||
Output quantized_op = Const(root.WithOpName("quantized_op"),
|
||||
Input::Initializer(quantized_tensor));
|
||||
|
||||
Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&quantized_min_tensor, {2.0f});
|
||||
Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
|
||||
Input::Initializer(quantized_min_tensor));
|
||||
|
||||
Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&quantized_max_tensor, {2.0f});
|
||||
Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
|
||||
Input::Initializer(quantized_min_tensor));
|
||||
|
||||
Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
|
||||
test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
|
||||
Output offset_op =
|
||||
Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
|
||||
|
||||
Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&offset_min_tensor, {0.0f});
|
||||
Output offset_min_op = Const(root.WithOpName("offset_min_op"),
|
||||
Input::Initializer(offset_min_tensor));
|
||||
|
||||
Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&offset_max_tensor, {255.0f});
|
||||
Output offset_max_op = Const(root.WithOpName("offset_max_op"),
|
||||
Input::Initializer(offset_max_tensor));
|
||||
|
||||
QuantizedBiasAdd quantized_bias_add_op(
|
||||
root.WithOpName("bias_add_op"), quantized_op, offset_op,
|
||||
quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
|
||||
DT_QINT32);
|
||||
|
||||
RequantizationRange requantization_range_op(
|
||||
root.WithOpName("requantization_range_op"),
|
||||
quantized_bias_add_op.output, quantized_bias_add_op.min_out,
|
||||
quantized_bias_add_op.max_out);
|
||||
|
||||
Requantize requantize_op(
|
||||
root.WithOpName("requantize_op"), quantized_bias_add_op.output,
|
||||
quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
|
||||
requantization_range_op.output_min, requantization_range_op.output_max,
|
||||
DT_QUINT8);
|
||||
|
||||
Output dequantize_op =
|
||||
Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
|
||||
requantize_op.output_min, requantize_op.output_max);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
const string min_max_log_file_name =
|
||||
io::JoinPath(testing::TmpDir(), "min_max_log_file.txt");
|
||||
{
|
||||
std::unique_ptr<WritableFile> file;
|
||||
TF_ASSERT_OK(
|
||||
Env::Default()->NewWritableFile(min_max_log_file_name, &file));
|
||||
TF_ASSERT_OK(file->Append("Something irrelevant\n"));
|
||||
TF_ASSERT_OK(
|
||||
file->Append("[SomePrefix] "
|
||||
";requantization_range_op__print__;__requant_min_max:"
|
||||
"[-2.4313571][10.584145]\n"));
|
||||
TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
|
||||
}
|
||||
|
||||
TransformFuncContext context;
|
||||
context.input_names = {};
|
||||
context.output_names = {"dequantize_op"};
|
||||
context.params = {{"min_max_log_file", {min_max_log_file_name}}};
|
||||
|
||||
GraphDef frozen_graph_def;
|
||||
TF_EXPECT_OK(
|
||||
FreezeRequantizationRanges(graph_def, context, &frozen_graph_def));
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(frozen_graph_def, &node_map);
|
||||
EXPECT_EQ(0, node_map.count("requantization_range_op"));
|
||||
EXPECT_EQ(1, node_map.count("requantize_op"));
|
||||
const string& min_input =
|
||||
NodeNameFromInput(node_map.at("requantize_op")->input(3));
|
||||
ASSERT_EQ(1, node_map.count(min_input));
|
||||
EXPECT_EQ("Const", node_map.at(min_input)->op());
|
||||
const string& max_input =
|
||||
NodeNameFromInput(node_map.at("requantize_op")->input(4));
|
||||
ASSERT_EQ(1, node_map.count(max_input));
|
||||
EXPECT_EQ("Const", node_map.at(max_input)->op());
|
||||
|
||||
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||
TF_ASSERT_OK(original_session->Create(graph_def));
|
||||
std::vector<Tensor> original_outputs;
|
||||
TF_ASSERT_OK(
|
||||
original_session->Run({}, {"dequantize_op"}, {}, &original_outputs));
|
||||
|
||||
std::unique_ptr<Session> frozen_session(NewSession(SessionOptions()));
|
||||
TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
|
||||
std::vector<Tensor> frozen_outputs;
|
||||
TF_ASSERT_OK(
|
||||
frozen_session->Run({}, {"dequantize_op"}, {}, &frozen_outputs));
|
||||
|
||||
ASSERT_EQ(original_outputs.size(), frozen_outputs.size());
|
||||
ASSERT_EQ(1, frozen_outputs.size());
|
||||
test::ExpectTensorNear<float>(original_outputs[0], frozen_outputs[0], 0.5);
|
||||
}
|
||||
|
||||
void TestExtractMinMaxRecords() {
|
||||
const string min_max_log_file_name =
|
||||
io::JoinPath(testing::TmpDir(), "min_max_log_file2.txt");
|
||||
{
|
||||
std::unique_ptr<WritableFile> file;
|
||||
TF_ASSERT_OK(
|
||||
Env::Default()->NewWritableFile(min_max_log_file_name, &file));
|
||||
TF_ASSERT_OK(file->Append("Something irrelevant\n"));
|
||||
TF_ASSERT_OK(
|
||||
file->Append("[SomePrefix] "
|
||||
";requantization_range_op__print__;__requant_min_max:"
|
||||
"[-2.4313571][10.584145]\n"));
|
||||
TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
|
||||
TF_ASSERT_OK(file->Append(
|
||||
"[SomeOtherPrefix] "
|
||||
";other_requantization_range_op__print__;__requant_min_max:"
|
||||
"[-1.0][2.0]\n"));
|
||||
TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
|
||||
TF_ASSERT_OK(
|
||||
file->Append("[SomePrefix] "
|
||||
";requantization_range_op__print__;__requant_min_max:"
|
||||
"[-1.bad][2.0]\n"));
|
||||
}
|
||||
std::vector<MinMaxRecord> records;
|
||||
TF_ASSERT_OK(ExtractMinMaxRecords(min_max_log_file_name, &records));
|
||||
ASSERT_EQ(2, records.size());
|
||||
EXPECT_EQ("requantization_range_op", records[0].name);
|
||||
EXPECT_NEAR(-2.4313571f, records[0].min, 1e-5f);
|
||||
EXPECT_NEAR(10.584145f, records[0].max, 1e-5f);
|
||||
EXPECT_EQ("other_requantization_range_op", records[1].name);
|
||||
EXPECT_NEAR(-1.0f, records[1].min, 1e-5f);
|
||||
EXPECT_NEAR(2.0f, records[1].max, 1e-5f);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FreezeRequantizationRangesTest, TestFreezeRequantizationRanges) {
|
||||
TestFreezeRequantizationRanges();
|
||||
}
|
||||
|
||||
TEST_F(FreezeRequantizationRangesTest, TestExtractMinMaxRecords) {
|
||||
TestExtractMinMaxRecords();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
153
tensorflow/tools/graph_transforms/insert_logging.cc
Normal file
153
tensorflow/tools/graph_transforms/insert_logging.cc
Normal file
@ -0,0 +1,153 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/subgraph.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Clears the device field of all ops in the graph.
|
||||
Status InsertLogging(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
std::unordered_set<string> ops;
|
||||
bool has_ops;
|
||||
if (context.params.count("op")) {
|
||||
has_ops = true;
|
||||
for (const string& op : context.params.at("op")) {
|
||||
ops.insert(op);
|
||||
}
|
||||
} else {
|
||||
has_ops = false;
|
||||
}
|
||||
|
||||
std::unordered_set<string> prefixes;
|
||||
bool has_prefixes;
|
||||
if (context.params.count("prefix")) {
|
||||
has_prefixes = true;
|
||||
for (const string& prefix : context.params.at("prefix")) {
|
||||
prefixes.insert(prefix);
|
||||
}
|
||||
} else {
|
||||
has_prefixes = false;
|
||||
}
|
||||
|
||||
string message;
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
|
||||
|
||||
bool show_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneBoolParameter("show_name", false, &show_name));
|
||||
|
||||
bool show_op;
|
||||
TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
|
||||
|
||||
int32 first_n;
|
||||
TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
|
||||
|
||||
int32 summarize;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneInt32Parameter("summarize", 1024, &summarize));
|
||||
|
||||
std::unordered_map<string, std::set<int>> node_outputs;
|
||||
for (const NodeDef& node : input_graph_def.node()) {
|
||||
for (const string& input : node.input()) {
|
||||
const string canonical_input = CanonicalInputName(input);
|
||||
string prefix;
|
||||
string name;
|
||||
string suffix;
|
||||
NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
|
||||
const string output_index_string = suffix.substr(1, suffix.size() - 1);
|
||||
int32 output_index;
|
||||
if (!strings::safe_strto32(output_index_string, &output_index)) {
|
||||
return errors::InvalidArgument("Couldn't understand output number in ",
|
||||
input);
|
||||
}
|
||||
node_outputs[name].insert(output_index);
|
||||
}
|
||||
}
|
||||
|
||||
std::map<string, string> inputs_to_rename;
|
||||
std::unordered_set<string> ignore_when_renaming;
|
||||
GraphDef logged_graph_def;
|
||||
for (const NodeDef& node : input_graph_def.node()) {
|
||||
NodeDef* new_node = logged_graph_def.mutable_node()->Add();
|
||||
new_node->CopyFrom(node);
|
||||
if (node_outputs[node.name()].empty()) {
|
||||
// There were no outputs found to this node, so skip it.
|
||||
continue;
|
||||
}
|
||||
const bool op_matches = (ops.count(node.op()) > 0);
|
||||
bool prefix_matches = false;
|
||||
for (const string& prefix : prefixes) {
|
||||
if (StringPiece(node.name()).starts_with(prefix)) {
|
||||
prefix_matches = true;
|
||||
}
|
||||
}
|
||||
// If we're not looking for ops, or we found the right op, and if we're not
|
||||
// looking for prefixes or we found the right prefix, then add logging here.
|
||||
if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
|
||||
const string name_suffix = "__print__";
|
||||
DataTypeVector input_types;
|
||||
DataTypeVector output_types;
|
||||
TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
|
||||
NodeDef* print_node = logged_graph_def.mutable_node()->Add();
|
||||
print_node->set_op("Print");
|
||||
print_node->set_name(strings::StrCat(node.name(), name_suffix));
|
||||
string node_message;
|
||||
if (show_op) {
|
||||
node_message += ";" + node.op() + ";";
|
||||
}
|
||||
if (show_name) {
|
||||
node_message += ";" + print_node->name() + ";";
|
||||
}
|
||||
node_message += message;
|
||||
SetNodeAttr("message", node_message, print_node);
|
||||
SetNodeAttr("first_n", first_n, print_node);
|
||||
SetNodeAttr("summarize", summarize, print_node);
|
||||
print_node->add_input(node.name() + ":0");
|
||||
SetNodeAttr("T", output_types[0], print_node);
|
||||
for (int output_index : node_outputs[node.name()]) {
|
||||
print_node->add_input(strings::StrCat(node.name(), ":", output_index));
|
||||
}
|
||||
SetNodeAttr("U", output_types, print_node);
|
||||
ignore_when_renaming.insert(print_node->name());
|
||||
// Rewrite the graph so all references to the first input of the original
|
||||
// op now pull from the print op instead, so it's executed.
|
||||
inputs_to_rename[node.name() + ":0"] =
|
||||
strings::StrCat(node.name(), name_suffix, ":0");
|
||||
}
|
||||
}
|
||||
|
||||
output_graph_def->Clear();
|
||||
RenameNodeInputs(logged_graph_def, inputs_to_rename, ignore_when_renaming,
|
||||
output_graph_def);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
203
tensorflow/tools/graph_transforms/insert_logging_test.cc
Normal file
203
tensorflow/tools/graph_transforms/insert_logging_test.cc
Normal file
@ -0,0 +1,203 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header. The definition lives in
// insert_logging.cc; the signature must stay in sync with it.
Status InsertLogging(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def);
|
||||
|
||||
// Fixture exercising the insert_logging graph transform. Each helper builds
// the same small graph:
//   add_node2 = const_node1 + const_node2
//   add_node3 = const_node1 + const_node3
//   mul_node1 = add_node2 * add_node3
//   add_node4 = mul_node1 + const_node3   (the fetched output)
// runs the transform with different parameters, and checks which nodes ended
// up with a Print op attached (identified by the Print node's first input).
class InsertLoggingTest : public ::testing::Test {
 protected:
  // Runs graph_def in a fresh session and asserts that fetching output_names
  // succeeds, i.e. the (possibly transformed) graph is still executable.
  void CheckGraphCanRun(const GraphDef& graph_def,
                        const std::vector<string>& output_names) {
    std::unique_ptr<Session> session(NewSession(SessionOptions()));
    TF_ASSERT_OK(session->Create(graph_def));
    std::vector<Tensor> outputs;
    TF_ASSERT_OK(session->Run({}, output_names, {}, &outputs));
  }

  // No "op"/"prefix" parameters: every node that feeds another node gets
  // instrumented. add_node4 is the graph output with no consumers, so it is
  // the only node without a Print op.
  void TestInsertLogging() {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    Tensor const_tensor(DT_FLOAT, TensorShape({10}));
    test::FillIota<float>(&const_tensor, 1.0f);
    Output const_node1 =
        Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
    Output const_node2 =
        Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
    Output const_node3 =
        Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
    Output add_node2 =
        Add(root.WithOpName("add_node2"), const_node1, const_node2);
    Output add_node3 =
        Add(root.WithOpName("add_node3"), const_node1, const_node3);
    Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
    Output add_node4 =
        Add(root.WithOpName("add_node4"), mul_node1, const_node3);
    GraphDef graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    CheckGraphCanRun(graph_def, {"add_node4"});

    GraphDef result;
    TransformFuncContext context;
    context.input_names = {};
    context.output_names = {"add_node4"};
    TF_ASSERT_OK(InsertLogging(graph_def, context, &result));

    // The transformed graph must still produce the fetched output.
    CheckGraphCanRun(result, {"add_node4"});

    // Collect the pass-through (first) input of every Print node.
    std::unordered_set<string> print_inputs;
    for (const NodeDef& node : result.node()) {
      if (node.op() == "Print") {
        print_inputs.insert(node.input(0));
      }
    }

    EXPECT_EQ(6, print_inputs.size());
    EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
    EXPECT_EQ(1, print_inputs.count("add_node2:0"));
    EXPECT_EQ(1, print_inputs.count("add_node3:0"));
    EXPECT_EQ(0, print_inputs.count("add_node4:0"));
    EXPECT_EQ(1, print_inputs.count("const_node1:0"));
    EXPECT_EQ(1, print_inputs.count("const_node2:0"));
    EXPECT_EQ(1, print_inputs.count("const_node3:0"));
  }

  // op={"Mul","Add"}: only the Add/Mul nodes with consumers are instrumented;
  // the Const nodes (and the consumer-less add_node4) are left alone.
  void TestInsertLoggingByOpType() {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    Tensor const_tensor(DT_FLOAT, TensorShape({10}));
    test::FillIota<float>(&const_tensor, 1.0f);
    Output const_node1 =
        Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
    Output const_node2 =
        Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
    Output const_node3 =
        Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
    Output add_node2 =
        Add(root.WithOpName("add_node2"), const_node1, const_node2);
    Output add_node3 =
        Add(root.WithOpName("add_node3"), const_node1, const_node3);
    Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
    Output add_node4 =
        Add(root.WithOpName("add_node4"), mul_node1, const_node3);
    GraphDef graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    CheckGraphCanRun(graph_def, {"add_node4"});

    GraphDef result;
    TransformFuncContext context;
    context.input_names = {};
    context.output_names = {"add_node4"};
    context.params.insert(
        std::pair<string, std::vector<string>>({"op", {"Mul", "Add"}}));
    TF_ASSERT_OK(InsertLogging(graph_def, context, &result));

    CheckGraphCanRun(result, {"add_node4"});

    std::unordered_set<string> print_inputs;
    for (const NodeDef& node : result.node()) {
      if (node.op() == "Print") {
        print_inputs.insert(node.input(0));
      }
    }

    EXPECT_EQ(3, print_inputs.size());
    EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
    EXPECT_EQ(1, print_inputs.count("add_node2:0"));
    EXPECT_EQ(1, print_inputs.count("add_node3:0"));
    EXPECT_EQ(0, print_inputs.count("add_node4:0"));
    EXPECT_EQ(0, print_inputs.count("const_node1:0"));
    EXPECT_EQ(0, print_inputs.count("const_node2:0"));
    EXPECT_EQ(0, print_inputs.count("const_node3:0"));
  }

  // prefix={"add_node"}: only add_node2 and add_node3 are instrumented —
  // add_node4 matches the prefix too, but has no consumers.
  void TestInsertLoggingByPrefix() {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    Tensor const_tensor(DT_FLOAT, TensorShape({10}));
    test::FillIota<float>(&const_tensor, 1.0f);
    Output const_node1 =
        Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
    Output const_node2 =
        Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
    Output const_node3 =
        Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
    Output add_node2 =
        Add(root.WithOpName("add_node2"), const_node1, const_node2);
    Output add_node3 =
        Add(root.WithOpName("add_node3"), const_node1, const_node3);
    Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
    Output add_node4 =
        Add(root.WithOpName("add_node4"), mul_node1, const_node3);
    GraphDef graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    CheckGraphCanRun(graph_def, {"add_node4"});

    GraphDef result;
    TransformFuncContext context;
    context.input_names = {};
    context.output_names = {"add_node4"};
    context.params.insert(
        std::pair<string, std::vector<string>>({"prefix", {"add_node"}}));
    TF_ASSERT_OK(InsertLogging(graph_def, context, &result));

    CheckGraphCanRun(result, {"add_node4"});

    std::unordered_set<string> print_inputs;
    for (const NodeDef& node : result.node()) {
      if (node.op() == "Print") {
        print_inputs.insert(node.input(0));
      }
    }

    EXPECT_EQ(2, print_inputs.size());
    EXPECT_EQ(0, print_inputs.count("mul_node1:0"));
    EXPECT_EQ(1, print_inputs.count("add_node2:0"));
    EXPECT_EQ(1, print_inputs.count("add_node3:0"));
    EXPECT_EQ(0, print_inputs.count("add_node4:0"));
    EXPECT_EQ(0, print_inputs.count("const_node1:0"));
    EXPECT_EQ(0, print_inputs.count("const_node2:0"));
    EXPECT_EQ(0, print_inputs.count("const_node3:0"));
  }
};
|
||||
|
||||
// Register the fixture's test methods with gtest; each TEST_F forwards to the
// corresponding helper defined on the fixture above.
TEST_F(InsertLoggingTest, TestInsertLogging) { TestInsertLogging(); }

TEST_F(InsertLoggingTest, TestInsertLoggingByOpType) {
  TestInsertLoggingByOpType();
}

TEST_F(InsertLoggingTest, TestInsertLoggingByPrefix) {
  TestInsertLoggingByPrefix();
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -236,7 +236,8 @@ Status MergeDuplicateNodes(const GraphDef& input_graph_def,
|
||||
}
|
||||
// Update the graph so that any nodes that referred to removed inputs now
|
||||
// pull from the remaining duplicate.
|
||||
RenameNodeInputs(merged_graph_def, inputs_to_rename, ¤t_graph_def);
|
||||
RenameNodeInputs(merged_graph_def, inputs_to_rename,
|
||||
std::unordered_set<string>(), ¤t_graph_def);
|
||||
} while (any_duplicates_found);
|
||||
|
||||
*output_graph_def = current_graph_def;
|
||||
@ -295,7 +296,8 @@ Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
|
||||
},
|
||||
{true}, &replaced_graph_def));
|
||||
|
||||
RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def);
|
||||
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
||||
std::unordered_set<string>(), output_graph_def);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -372,9 +374,9 @@ Status QuantizePlaceholders(const GraphDef& input_graph_def,
|
||||
|
||||
GraphDef first_pass_graph_def;
|
||||
RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
|
||||
&first_pass_graph_def);
|
||||
std::unordered_set<string>(), &first_pass_graph_def);
|
||||
RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
|
||||
output_graph_def);
|
||||
std::unordered_set<string>(), output_graph_def);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
||||
{true}, &replaced_graph_def));
|
||||
// Make sure all references to removed nodes now point to their inputs.
|
||||
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
||||
¤t_graph_def);
|
||||
std::unordered_set<string>(), ¤t_graph_def);
|
||||
} while (any_nodes_removed);
|
||||
}
|
||||
|
||||
|
@ -33,8 +33,9 @@ namespace graph_transforms {
|
||||
Status RoundWeights(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
int64 num_steps;
|
||||
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
|
||||
int32 num_steps;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneInt32Parameter("num_steps", 256, &num_steps));
|
||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||
input_graph_def, {"Const"},
|
||||
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||
|
@ -469,6 +469,7 @@ Status ReplaceMatchingOpTypes(
|
||||
|
||||
Status RenameNodeInputs(const GraphDef& input_graph_def,
|
||||
const std::map<string, string>& inputs_to_rename,
|
||||
const std::unordered_set<string>& nodes_to_ignore,
|
||||
GraphDef* output_graph_def) {
|
||||
std::map<string, std::vector<std::pair<string, string>>>
|
||||
canonical_inputs_to_rename;
|
||||
@ -494,6 +495,9 @@ Status RenameNodeInputs(const GraphDef& input_graph_def,
|
||||
input_node_name);
|
||||
}
|
||||
already_visited.insert(input_node_name);
|
||||
if (nodes_to_ignore.count(node.name())) {
|
||||
break;
|
||||
}
|
||||
bool any_match_found = false;
|
||||
for (const std::pair<string, string>& input_to_rename :
|
||||
canonical_inputs_to_rename.at(input_node_name)) {
|
||||
@ -630,7 +634,24 @@ Status TransformFuncContext::GetOneStringParameter(const string& name,
|
||||
}
|
||||
}
|
||||
|
||||
Status TransformFuncContext::GetOneIntParameter(const string& name,
|
||||
// Looks up the parameter `name` and parses it as a 32-bit integer. When the
// parameter is absent, *result is set to default_value and OK is returned.
// Returns InvalidArgument when the value can't be parsed as an int32 (this
// includes floating-point strings, which safe_strto32 rejects).
Status TransformFuncContext::GetOneInt32Parameter(const string& name,
                                                  int32 default_value,
                                                  int32* result) const {
  const int params_count = CountParameters(name);
  // Absent parameters fall back to the caller-supplied default.
  if (params_count == 0) {
    *result = default_value;
    return Status::OK();
  }
  string string_value;
  TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
  if (!strings::safe_strto32(StringPiece(string_value), result)) {
    return errors::InvalidArgument("Couldn't interpret the ", name,
                                   " argument as a number:", string_value);
  }
  return Status::OK();
}
|
||||
|
||||
Status TransformFuncContext::GetOneInt64Parameter(const string& name,
|
||||
int64 default_value,
|
||||
int64* result) const {
|
||||
const int params_count = CountParameters(name);
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
|
||||
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -201,9 +203,11 @@ Status ReplaceMatchingOpTypes(
|
||||
// Returns a list of the unique nodes found in this match.
|
||||
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
|
||||
|
||||
// Changes all input references to a particular node name.
|
||||
// Changes all input references to a particular node name. Any nodes with names
|
||||
// listed in nodes_to_ignore will not have their inputs rewritten.
|
||||
Status RenameNodeInputs(const GraphDef& input_graph_def,
|
||||
const std::map<string, string>& inputs_to_rename,
|
||||
const std::unordered_set<string>& nodes_to_ignore,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
// Utility function that copies all the nodes found in a match into the
|
||||
@ -225,10 +229,16 @@ struct TransformFuncContext {
|
||||
Status GetOneStringParameter(const string& name, const string& default_value,
|
||||
string* result) const;
|
||||
|
||||
// Gets a single occurrence of a parameter as an integer, falling back to a
|
||||
// default if it isn't present and returning an error if it isn't convertible
|
||||
// to a number.
|
||||
Status GetOneIntParameter(const string& name, int64 default_value,
|
||||
// Gets a single occurrence of a parameter as a 32-bit integer, falling back
|
||||
// to a default if it isn't present and returning an error if it isn't
|
||||
// convertible to a number.
|
||||
Status GetOneInt32Parameter(const string& name, int32 default_value,
|
||||
int32* result) const;
|
||||
|
||||
// Gets a single occurrence of a parameter as a 64-bit integer, falling back
|
||||
// to a default if it isn't present and returning an error if it isn't
|
||||
// convertible to a number.
|
||||
Status GetOneInt64Parameter(const string& name, int64 default_value,
|
||||
int64* result) const;
|
||||
|
||||
// Gets a single occurrence of a parameter as a floating point number, falling
|
||||
|
@ -541,7 +541,9 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
GraphDef renamed_graph_def;
|
||||
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def));
|
||||
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}},
|
||||
std::unordered_set<string>(),
|
||||
&renamed_graph_def));
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(renamed_graph_def, &node_map);
|
||||
@ -579,7 +581,7 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
GraphDef renamed_graph_def;
|
||||
TF_ASSERT_OK(RenameNodeInputs(
|
||||
graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
|
||||
&renamed_graph_def));
|
||||
std::unordered_set<string>(), &renamed_graph_def));
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(renamed_graph_def, &node_map);
|
||||
@ -615,8 +617,9 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
GraphDef renamed_graph_def;
|
||||
Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
|
||||
&renamed_graph_def);
|
||||
Status rename_status =
|
||||
RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
|
||||
std::unordered_set<string>(), &renamed_graph_def);
|
||||
EXPECT_FALSE(rename_status.ok());
|
||||
}
|
||||
|
||||
@ -650,6 +653,7 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
|
||||
GraphDef renamed_graph_def;
|
||||
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
|
||||
std::unordered_set<string>(),
|
||||
&renamed_graph_def));
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
@ -658,6 +662,45 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
|
||||
}
|
||||
|
||||
void TestRenameNodeInputsWithIgnores() {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
const int width = 10;
|
||||
|
||||
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||
test::FillIota<float>(&a_data, 1.0f);
|
||||
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||
|
||||
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||
test::FillIota<float>(&b_data, 1.0f);
|
||||
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||
|
||||
Output add = Add(root.WithOpName("add"), a_const, a_const);
|
||||
|
||||
Output add2 = Add(root.WithOpName("add2"), a_const, a_const);
|
||||
|
||||
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||
|
||||
Output mul = Mul(root.WithOpName("mul"), add, placeholder);
|
||||
|
||||
Output mul2 = Mul(root.WithOpName("output"), mul, add2);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
GraphDef renamed_graph_def;
|
||||
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"},
|
||||
&renamed_graph_def));
|
||||
|
||||
std::map<string, const NodeDef*> node_map;
|
||||
MapNamesToNodes(renamed_graph_def, &node_map);
|
||||
EXPECT_EQ("b", node_map.at("add")->input(0));
|
||||
EXPECT_EQ("b", node_map.at("add")->input(1));
|
||||
EXPECT_EQ("a", node_map.at("add2")->input(0));
|
||||
EXPECT_EQ("a", node_map.at("add2")->input(1));
|
||||
}
|
||||
|
||||
void TestFindInvalidInputs() {
|
||||
GraphDef graph_def;
|
||||
|
||||
@ -943,20 +986,36 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
EXPECT_EQ("d", value);
|
||||
}
|
||||
|
||||
void TestGetOneIntParameter() {
|
||||
void TestGetOneInt32Parameter() {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"10", "20"}});
|
||||
context.params.insert({"bar", {"-23"}});
|
||||
context.params.insert({"not_a_number", {"not_numerical"}});
|
||||
context.params.insert({"float", {"-23.232323"}});
|
||||
int32 value;
|
||||
TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value));
|
||||
EXPECT_EQ(-23, value);
|
||||
EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok());
|
||||
TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value));
|
||||
EXPECT_EQ(10, value);
|
||||
EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok());
|
||||
EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok());
|
||||
}
|
||||
|
||||
void TestGetOneInt64Parameter() {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"10", "20"}});
|
||||
context.params.insert({"bar", {"-23"}});
|
||||
context.params.insert({"not_a_number", {"not_numerical"}});
|
||||
context.params.insert({"float", {"-23.232323"}});
|
||||
int64 value;
|
||||
TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
|
||||
TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value));
|
||||
EXPECT_EQ(-23, value);
|
||||
EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
|
||||
TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
|
||||
EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok());
|
||||
TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value));
|
||||
EXPECT_EQ(10, value);
|
||||
EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
|
||||
EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
|
||||
EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok());
|
||||
EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok());
|
||||
}
|
||||
|
||||
void TestGetOneFloatParameter() {
|
||||
@ -1111,6 +1170,10 @@ TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
|
||||
TestRenameNodeInputsWithWildcard();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) {
|
||||
TestRenameNodeInputsWithIgnores();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
|
||||
@ -1127,7 +1190,13 @@ TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
|
||||
TestGetOneStringParameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
|
||||
TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) {
|
||||
TestGetOneInt32Parameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) {
|
||||
TestGetOneInt64Parameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
|
||||
TestGetOneFloatParameter();
|
||||
|
Loading…
Reference in New Issue
Block a user