Optimize eight-bit graphs by removing RequantizationRanges

Change: 144145086
This commit is contained in:
Pete Warden 2017-01-10 16:46:05 -08:00 committed by TensorFlower Gardener
parent 5d9a05a8b0
commit b42ba8aace
13 changed files with 1050 additions and 50 deletions

View File

@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

View File

@ -60,7 +60,9 @@ cc_library(
"fold_batch_norms.cc",
"fold_constants_lib.cc",
"fold_old_batch_norms.cc",
"freeze_requantization_ranges.cc",
"fuse_convolutions.cc",
"insert_logging.cc",
"obsfucate_names.cc",
"quantize_nodes.cc",
"quantize_weights.cc",
@ -101,7 +103,9 @@ tf_cc_test(
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
"fold_old_batch_norms_test.cc",
"freeze_requantization_ranges_test.cc",
"fuse_convolutions_test.cc",
"insert_logging_test.cc",
"obsfucate_names_test.cc",
"quantize_nodes_test.cc",
"quantize_weights_test.cc",

View File

@ -16,7 +16,9 @@
* [fold_batch_norms](#fold_batch_norms)
* [fold_constants](#fold_constants)
* [fold_old_batch_norms](#fold_old_batch_norms)
* [freeze_requantization_ranges](#freeze_requantization_ranges)
* [fuse_convolutions](#fuse_convolutions)
* [insert_logging](#insert_logging)
* [merge_duplicate_nodes](#merge_duplicate_nodes)
* [obsfucate_names](#obsfucate_names)
* [quantize_nodes](#quantize_nodes)
@ -334,7 +336,7 @@ within the saved model, and sets them to the defined default for that attribute.
### fold_batch_norms
Args: None
Args: None \
Prerequisites: [fold_constants](#fold_constants)
This transform tries to optimize away the Mul that's introduced after a Conv2D
@ -347,7 +349,7 @@ produced by training for the Mul input is collapsed down into a simple constant.
### fold_constants
Args: None\
Args: None \
Prerequisites: None
Looks for any sub-graphs within the model that always evaluate to constant
@ -359,7 +361,7 @@ to continue on past transient errors, since this is just an optimization phase.
### fold_old_batch_norms
Args: None\
Args: None \
Prerequisites: None
In the early days of TensorFlow, batch normalization was implemented using a
@ -370,21 +372,143 @@ have a graph that uses the older-style, this transform will recognize and
optimize those ops for inference, in the same way that the
[fold_batch_norms](#fold_batch_norms) transform does for the new approach.
### freeze_requantization_ranges
Args:
* min_max_log_file: Path to a log file containing ranges for ops.
* min_percentile: Percentage cutoff to use to calculate an overall min.
Defaults to 5.
* max_percentile: Percentage cutoff to use to calculate an overall max.
Defaults to 5.
Quantized operations like convolution or matrix multiplies take their inputs as
8-bit, but produce 32-bit results. To do further operations on these, they need
to be converted back down to the lower depth. To make the most of those eight
bits, you need to scale the thirty-two bits of original data down using a scale
that matches the range that's actually being used.
Because that range information isn't stored in the original graph, the
[quantization process](#eight-bit-calculations) inserts RequantizationRange ops
before each conversion from 32 to 8 bits. This op looks at the 32-bit output and
calculates the current min and max every time it's run.
This isn't incredibly time-consuming, but it is extra work that's nice to avoid
if possible. One way of optimizing that away is replacing those
RequantizationRange ops with a pair of Const nodes holding known min/max values,
so the scaling down can be done without having to inspect the output every time.
That's what this transform does. It's usually used in conjunction with a copy of
the graph that's had [insert_logging](#insert_logging) run on it to instrument
it to record the min/max values to stderr. Why is logging used rather than
writing to a normal file? As you'll see later, to get best results you want to
collect data from a lot of runs on real data, and for mobile apps especially
it's a lot easier to do this by copying log files. As an example, here's how
you'd add the logging operations for a quantized version of the Inception v3
graph:
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--logtostderr \
--in_graph=/tmp/quantized_inception.pb \
--out_graph=/tmp/logged_quantized_inception.pb \
--inputs=Mul \
--outputs=softmax \
--transforms='\
insert_logging(op=RequantizationRange, show_name=true, message="__requant_min_max:")\
'
```
Now, when you run the `/tmp/logged_quantized_inception.pb` graph, it will write
out log statements that show the value of the min and max calculated by each
RequantizationRange op. Here's an example of running label_image and saving the
log:
```bash
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Downloads/grace_hopper.jpg \
--logtostderr \
--input_layer=Mul \
--output_layer=softmax \
--graph=/tmp/logged_quantized_inception.pb \
--labels=learning/brain/models/image/inception_v3/imagenet_comp_graph_label_strings.txt \
--logtostderr \
2>/tmp/min_max_log_small.txt
```
If you look in `/tmp/min_max_log_small.txt`, you'll see a lot of lines like
this:
```
I0108 21:45:42.261883 1972 logging_ops.cc:79] ;conv/Conv2D/eightbit/requant_range__print__;__requant_min_max:[-20.887871][22.274715]
```
This is a simple way of serializing the name of the RequantizationRange op and
its min/max values every time it's run. It's a file like this that you pass into
the transform as the `min_max_log_file` argument. The transform will attempt to
extract all of the min/max values associated with ops, ignoring any irrelevant
lines in the log, and replace the RequantizationRange ops with two Const nodes
containing the found values.
This isn't the whole story though. The min/max values can vary a lot depending
on what the particular inputs to the graph are on any given run, which means
picking ranges based on just one run can lead to clipping of values and a loss
of accuracy. To get better results, you need to run your network against a range
of different inputs. In Inception's case, I often use a thousand different
images from the training set. You can then pass the whole concatenated log from
all of the runs into the transform, and it will pick ranges based on the
aggregate of the values found for each RequantizationRange op.
To ensure that outliers don't increase the range too much, and so decrease the
accuracy by putting too many bits into rare extreme values, the `min_percentile`
and `max_percentile` arguments control how the overall min and max are chosen.
At their default values of 5, this means that the lowest 5% of the minimum
values will be discarded, taking the minimum of the remainder, and the
equivalent for the maximum.
### fuse_convolutions
Args: None\
Args: None \
Prerequisites: None
For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g.
to scale up in the later stages of an image style transfer model),
it can improve memory usage and latency to combine the spatial
transformations with the convolution's im2col patch generation. This transform
looks out for that particular pattern of ops and replaces them with a fused
version that combines the resizing and padding with the convolution.
For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g. to
scale up in the later stages of an image style transfer model), it can improve
memory usage and latency to combine the spatial transformations with the
convolution's im2col patch generation. This transform looks out for that
particular pattern of ops and replaces them with a fused version that combines
the resizing and padding with the convolution.
### insert_logging
Args:
* op: Insert a Print after every occurrence of this op type. Can be repeated
to cover multiple types. If not present, all op types will be instrumented.
* prefix: Insert a Print after every node whose name starts with this value.
Can be repeated to cover multiple nodes. If not present, all node names will
be matched.
* show_op: If true, the op type will be prepended to all log messages.
* show_name: If true, the node's name will be prepended to all log messages.
* message: Arbitrary text to log before the values.
* first_n: How many times to print before suppressing. Defaults to -1, which
means never stop.
* summarize: How long numerical results can be before they're truncated.
Defaults to 1024.
The Print operator writes strings to stderr when it's run inside a graph, and
prints out the numerical results of the node that it's reading from. This can be
very useful when you're debugging and want to follow particular internal values
while a graph is running. This transform allows you to insert those ops at
particular points in the graph, and customize the message that's displayed. It's
also used in conjunction with the
[freeze_requantization_ranges](#freeze_requantization_ranges) transform to
output information that it needs.
### merge_duplicate_nodes
Args: None\
Args: None \
Prerequisites: None
If there are Const nodes with the same types and contents, or nodes with the
@ -396,7 +520,7 @@ duplicates of constants that are used in the quantize/dequantize process).
### obsfucate_names
Args: None\
Args: None \
Prerequisites: None
Replaces all nodes' names with short generated ids, other than the inputs and
@ -429,14 +553,14 @@ Prerequisites: [quantize_weights](#quantize_weights)
Replaces any calculation nodes with their eight-bit equivalents (if available),
and adds in conversion layers to allow remaining float operations to
interoperate. This is one of the most complex transforms, and involves multiple
passes and a lot of rewriting. It's also still an active area of research,
so results may vary depending on the platform and operations you're using in
your model. You should run quantize_weights first to ensure your Const ops are
in eight-bit form.
passes and a lot of rewriting. It's also still an active area of research, so
results may vary depending on the platform and operations you're using in your
model. You should run quantize_weights first to ensure your Const ops are in
eight-bit form.
### quantize_weights
Args: None\
Args: None \
Prerequisites: None
Converts any large (more than 15 element) float Const op into an eight-bit
@ -461,7 +585,7 @@ special circumstances though.
### remove_device
Args: None
Args: None \
Prerequisites: None
All ops can have a hardware device specified. This can be a problem when you're
@ -548,7 +672,7 @@ device assigned.
### sort_by_execution_order
Args: None\
Args: None \
Prerequisites: None
Arranges the nodes in the GraphDef in topological order, so that the inputs of
@ -741,8 +865,8 @@ transform:
This is looking for QuantizeV2 nodes, with three inputs, the first of which is a
Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the
third is a Max which does the same. Assuming we know the Dequantize ops are
pulling from the same eight-bit buffer, the end result of this sub-graph is
a no-op, since it's just turning the eight-bit buffer into float, and then
pulling from the same eight-bit buffer, the end result of this sub-graph is a
no-op, since it's just turning the eight-bit buffer into float, and then
immediately converting it back to eight-bits, so if we look for this pattern and
remove it we can optimize the graph without changing the result.
@ -874,7 +998,7 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
parameter:
```C++
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("num_steps", 256, &num_steps));
```
If the conversion fails or the parameter occurs more than once the helper

View File

@ -0,0 +1,213 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// One observed requantization range for a named node, parsed from a single
// log line produced by the insert_logging transform.
struct MinMaxRecord {
  string name;  // Name of the node the range was recorded for.
  float min;    // Minimum value parsed from the log entry.
  float max;    // Maximum value parsed from the log entry.
};
// Try to parse a log file containing loosely-structured lines, some of which
// are the min/max logs we want. Lines are expected to look like
// ";<node_name>__print__;__requant_min_max:[<min>][<max>]" embedded in
// arbitrary other output; anything that doesn't match is silently skipped.
Status ExtractMinMaxRecords(const string& log_file_name,
                            std::vector<MinMaxRecord>* records) {
  string file_data;
  TF_RETURN_IF_ERROR(
      ReadFileToString(Env::Default(), log_file_name, &file_data));
  const string print_suffix("__print__");
  const string requant_prefix("__requant_min_max:");
  std::vector<string> file_lines = str_util::Split(file_data, '\n');
  for (const string& file_line : file_lines) {
    // We expect to find a line with components separated by semicolons, so to
    // start make sure that the basic structure is in place.
    StringPiece line(file_line);
    if (!line.contains(print_suffix + ";" + requant_prefix)) {
      continue;
    }
    std::vector<string> line_parts = str_util::Split(file_line, ';');
    if (line_parts.size() < 2) {
      continue;
    }
    // Now we want to figure out which components have the name and min max
    // values by scanning for the prefix we expect. The index is initialized
    // so a garbage value can never be read if the scan logic changes.
    bool min_max_found = false;
    int min_max_index = -1;
    for (int i = 1; i < line_parts.size(); ++i) {
      StringPiece line_part(line_parts[i]);
      if (line_part.starts_with(requant_prefix)) {
        min_max_found = true;
        min_max_index = i;
      }
    }
    if (!min_max_found) {
      continue;
    }
    // Finally we need to break out the values from the strings, and parse them
    // into a form we can use. References are used here to avoid copying each
    // line fragment into a fresh string.
    const string& min_max_string = line_parts[min_max_index];
    std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
    if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
      continue;
    }
    const string& min_string = min_max_parts[1];
    std::vector<string> min_string_parts = str_util::Split(min_string, ']');
    if (min_string_parts.size() != 2) {
      continue;
    }
    const string& min_number_string = min_string_parts[0];
    float min;
    if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
      continue;
    }
    const string& max_string = min_max_parts[2];
    std::vector<string> max_string_parts = str_util::Split(max_string, ']');
    if (max_string_parts.size() != 2) {
      continue;
    }
    const string& max_number_string = max_string_parts[0];
    float max;
    if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
      continue;
    }
    // The node name is the component just before the min/max values, with a
    // "__print__" suffix that we strip off.
    StringPiece name_string = line_parts[min_max_index - 1];
    if (!name_string.ends_with(print_suffix)) {
      continue;
    }
    string name =
        name_string.substr(0, name_string.size() - print_suffix.size())
            .ToString();
    records->push_back({name, min, max});
  }
  return Status::OK();
}
// Uses the observed min/max values for requantization captured in a log file to
// replace costly RequantizationRange ops with simple Consts.
//
// Parameters (read from `context`):
//   min_max_log_file: required path to a log written by a graph instrumented
//       with insert_logging.
//   min_percentile / max_percentile: percentage of outlier minimums/maximums
//       to discard before choosing the overall range. Default 5.0 each.
Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
                                  const TransformFuncContext& context,
                                  GraphDef* output_graph_def) {
  string min_max_log_file;
  TF_RETURN_IF_ERROR(
      context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
  if (min_max_log_file.empty()) {
    return errors::InvalidArgument(
        "You must pass a file name to min_max_log_file");
  }
  float min_percentile;
  TF_RETURN_IF_ERROR(
      context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
  float max_percentile;
  TF_RETURN_IF_ERROR(
      context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
  std::vector<MinMaxRecord> records;
  TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
  if (records.empty()) {
    return errors::InvalidArgument(
        "No min/max range logs were found in the log file");
  }
  std::map<string, const NodeDef*> node_map;
  MapNamesToNodes(input_graph_def, &node_map);
  // Group records by node name, and fail early if the log mentions any node
  // that isn't in this graph (it was probably produced from a different model).
  bool any_missing_nodes = false;
  std::map<string, std::vector<MinMaxRecord>> records_by_node;
  for (const MinMaxRecord& record : records) {
    records_by_node[record.name].push_back(record);
    if (!node_map.count(record.name)) {
      any_missing_nodes = true;
      LOG(WARNING) << "Node from log not found in graph: " << record.name;
    }
  }
  if (any_missing_nodes) {
    return errors::InvalidArgument(
        "Nodes were found in the log file that aren't present in the graph");
  }
  // Now find out the largest and smallest min/max values for the node.
  std::map<string, std::pair<float, float>> range_for_nodes;
  for (const auto& record_info : records_by_node) {
    const string& name = record_info.first;
    // Bind a reference here: the original code made an accidental deep copy
    // of this vector (and shadowed the outer `records` while doing so).
    const std::vector<MinMaxRecord>& node_records = record_info.second;
    std::vector<float> mins;
    std::vector<float> maxs;
    mins.reserve(node_records.size());
    maxs.reserve(node_records.size());
    for (const MinMaxRecord& record : node_records) {
      mins.push_back(record.min);
      maxs.push_back(record.max);
    }
    std::sort(mins.begin(), mins.end());
    std::sort(maxs.begin(), maxs.end());
    // Discard the requested percentage of outliers at each end. The indices
    // are clamped on both sides so that percentiles at or beyond 100 can't
    // index outside the sorted vectors (the original only clamped one side).
    int min_index = std::round(mins.size() * (min_percentile / 100.0f));
    if (min_index < 0) {
      min_index = 0;
    }
    if (min_index > static_cast<int>(mins.size()) - 1) {
      min_index = mins.size() - 1;
    }
    int max_index =
        std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
    if (max_index < 0) {
      max_index = 0;
    }
    if (max_index > static_cast<int>(maxs.size()) - 1) {
      max_index = maxs.size() - 1;
    }
    const float min = mins[min_index];
    const float max = maxs[max_index];
    range_for_nodes[name] = {min, max};
  }
  // Replace each RequantizationRange op we have a range for with a pair of
  // Const nodes holding the frozen min/max, and rewire consumers of its two
  // outputs (:0 = min, :1 = max) to read from those Consts.
  std::map<string, string> inputs_to_rename;
  GraphDef frozen_graph_def;
  for (const NodeDef& node : input_graph_def.node()) {
    if (range_for_nodes.count(node.name())) {
      if (node.op() != "RequantizationRange") {
        return errors::InvalidArgument(
            "Node is expected to be a RequantizationRange op: ", node.name(),
            ", but is: ", node.op());
      }
      const float min_value = range_for_nodes.at(node.name()).first;
      NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
      min_node->set_op("Const");
      min_node->set_name(node.name() + "/frozen_min");
      SetNodeAttr("dtype", DT_FLOAT, min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = min_value;
      SetNodeTensorAttr<float>("value", min_tensor, min_node);
      inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
      const float max_value = range_for_nodes.at(node.name()).second;
      NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
      max_node->set_op("Const");
      max_node->set_name(node.name() + "/frozen_max");
      SetNodeAttr("dtype", DT_FLOAT, max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = max_value;
      SetNodeTensorAttr<float>("value", max_tensor, max_node);
      inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
    } else {
      // All other nodes pass through unchanged.
      NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
      new_node->CopyFrom(node);
    }
  }
  RenameNodeInputs(frozen_graph_def, inputs_to_rename,
                   std::unordered_set<string>(), output_graph_def);
  return Status::OK();
}
// Makes the transform available to the transform_graph tool under this name.
REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
                         FreezeRequantizationRanges);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,200 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
                                  const TransformFuncContext& context,
                                  GraphDef* output_graph_def);
// Mirrors the MinMaxRecord struct defined in freeze_requantization_ranges.cc,
// so ExtractMinMaxRecords can be declared without a shared header.
struct MinMaxRecord {
  string name;
  float min;
  float max;
};
Status ExtractMinMaxRecords(const string& log_file_name,
                            std::vector<MinMaxRecord>* records);
class FreezeRequantizationRangesTest : public ::testing::Test {
 protected:
  // Builds a small quantized graph (bias add -> requantization range ->
  // requantize -> dequantize), writes a fake min/max log file, runs the
  // freeze_requantization_ranges transform, and verifies that the
  // RequantizationRange op is replaced by Const min/max inputs while the
  // graph's output stays close to the original.
  void TestFreezeRequantizationRanges() {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
    test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
    Output quantized_op = Const(root.WithOpName("quantized_op"),
                                Input::Initializer(quantized_tensor));
    Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
    test::FillValues<float>(&quantized_min_tensor, {2.0f});
    Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
                                    Input::Initializer(quantized_min_tensor));
    Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
    test::FillValues<float>(&quantized_max_tensor, {2.0f});
    // Fixed copy/paste bug: the max op was previously initialized from
    // quantized_min_tensor, leaving quantized_max_tensor unused. Both tensors
    // happen to hold 2.0f, so the test's behavior is unchanged.
    Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
                                    Input::Initializer(quantized_max_tensor));
    Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
    test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
    Output offset_op =
        Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
    Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
    test::FillValues<float>(&offset_min_tensor, {0.0f});
    Output offset_min_op = Const(root.WithOpName("offset_min_op"),
                                 Input::Initializer(offset_min_tensor));
    Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
    test::FillValues<float>(&offset_max_tensor, {255.0f});
    Output offset_max_op = Const(root.WithOpName("offset_max_op"),
                                 Input::Initializer(offset_max_tensor));
    QuantizedBiasAdd quantized_bias_add_op(
        root.WithOpName("bias_add_op"), quantized_op, offset_op,
        quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
        DT_QINT32);
    RequantizationRange requantization_range_op(
        root.WithOpName("requantization_range_op"),
        quantized_bias_add_op.output, quantized_bias_add_op.min_out,
        quantized_bias_add_op.max_out);
    Requantize requantize_op(
        root.WithOpName("requantize_op"), quantized_bias_add_op.output,
        quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
        requantization_range_op.output_min, requantization_range_op.output_max,
        DT_QUINT8);
    Output dequantize_op =
        Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
                   requantize_op.output_min, requantize_op.output_max);
    GraphDef graph_def;
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
    // Write a log file in the format insert_logging produces, including
    // irrelevant lines that the parser should skip.
    const string min_max_log_file_name =
        io::JoinPath(testing::TmpDir(), "min_max_log_file.txt");
    {
      std::unique_ptr<WritableFile> file;
      TF_ASSERT_OK(
          Env::Default()->NewWritableFile(min_max_log_file_name, &file));
      TF_ASSERT_OK(file->Append("Something irrelevant\n"));
      TF_ASSERT_OK(
          file->Append("[SomePrefix] "
                       ";requantization_range_op__print__;__requant_min_max:"
                       "[-2.4313571][10.584145]\n"));
      TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
    }
    TransformFuncContext context;
    context.input_names = {};
    context.output_names = {"dequantize_op"};
    context.params = {{"min_max_log_file", {min_max_log_file_name}}};
    GraphDef frozen_graph_def;
    TF_EXPECT_OK(
        FreezeRequantizationRanges(graph_def, context, &frozen_graph_def));
    // The RequantizationRange op should be gone, and the Requantize op should
    // now take its min/max inputs from Const nodes.
    std::map<string, const NodeDef*> node_map;
    MapNamesToNodes(frozen_graph_def, &node_map);
    EXPECT_EQ(0, node_map.count("requantization_range_op"));
    EXPECT_EQ(1, node_map.count("requantize_op"));
    const string& min_input =
        NodeNameFromInput(node_map.at("requantize_op")->input(3));
    ASSERT_EQ(1, node_map.count(min_input));
    EXPECT_EQ("Const", node_map.at(min_input)->op());
    const string& max_input =
        NodeNameFromInput(node_map.at("requantize_op")->input(4));
    ASSERT_EQ(1, node_map.count(max_input));
    EXPECT_EQ("Const", node_map.at(max_input)->op());
    // Running both graphs should produce nearly identical results.
    std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
    TF_ASSERT_OK(original_session->Create(graph_def));
    std::vector<Tensor> original_outputs;
    TF_ASSERT_OK(
        original_session->Run({}, {"dequantize_op"}, {}, &original_outputs));
    std::unique_ptr<Session> frozen_session(NewSession(SessionOptions()));
    TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
    std::vector<Tensor> frozen_outputs;
    TF_ASSERT_OK(
        frozen_session->Run({}, {"dequantize_op"}, {}, &frozen_outputs));
    ASSERT_EQ(original_outputs.size(), frozen_outputs.size());
    ASSERT_EQ(1, frozen_outputs.size());
    test::ExpectTensorNear<float>(original_outputs[0], frozen_outputs[0], 0.5);
  }
  // Verifies that ExtractMinMaxRecords pulls valid entries out of a log file
  // while skipping unrelated lines and entries with unparseable numbers.
  void TestExtractMinMaxRecords() {
    const string min_max_log_file_name =
        io::JoinPath(testing::TmpDir(), "min_max_log_file2.txt");
    {
      std::unique_ptr<WritableFile> file;
      TF_ASSERT_OK(
          Env::Default()->NewWritableFile(min_max_log_file_name, &file));
      TF_ASSERT_OK(file->Append("Something irrelevant\n"));
      TF_ASSERT_OK(
          file->Append("[SomePrefix] "
                       ";requantization_range_op__print__;__requant_min_max:"
                       "[-2.4313571][10.584145]\n"));
      TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
      TF_ASSERT_OK(file->Append(
          "[SomeOtherPrefix] "
          ";other_requantization_range_op__print__;__requant_min_max:"
          "[-1.0][2.0]\n"));
      TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
      // The min value here can't be parsed as a float, so this entry should
      // be skipped.
      TF_ASSERT_OK(
          file->Append("[SomePrefix] "
                       ";requantization_range_op__print__;__requant_min_max:"
                       "[-1.bad][2.0]\n"));
    }
    std::vector<MinMaxRecord> records;
    TF_ASSERT_OK(ExtractMinMaxRecords(min_max_log_file_name, &records));
    ASSERT_EQ(2, records.size());
    EXPECT_EQ("requantization_range_op", records[0].name);
    EXPECT_NEAR(-2.4313571f, records[0].min, 1e-5f);
    EXPECT_NEAR(10.584145f, records[0].max, 1e-5f);
    EXPECT_EQ("other_requantization_range_op", records[1].name);
    EXPECT_NEAR(-1.0f, records[1].min, 1e-5f);
    EXPECT_NEAR(2.0f, records[1].max, 1e-5f);
  }
};
// Exercises the end-to-end transform test defined on the fixture above.
TEST_F(FreezeRequantizationRangesTest, TestFreezeRequantizationRanges) {
  TestFreezeRequantizationRanges();
}
// Exercises the log-parsing test defined on the fixture above.
TEST_F(FreezeRequantizationRangesTest, TestExtractMinMaxRecords) {
  TestExtractMinMaxRecords();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,153 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Inserts Print ops into the graph so the values flowing out of selected
// nodes are logged to stderr when the graph runs. (The previous comment here,
// "Clears the device field of all ops in the graph.", was copy/pasted from
// another transform and was wrong.)
//
// Parameters (read from `context`):
//   op: op types to instrument (repeatable; all types if absent).
//   prefix: node-name prefixes to instrument (repeatable; all if absent).
//   show_op / show_name: prepend the op type / node name to log messages.
//   message: arbitrary text logged before the values.
//   first_n: how many times to print before suppressing (-1 = never stop).
//   summarize: how many elements of each tensor to print. Default 1024.
Status InsertLogging(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  // Collect the op-type and name-prefix filters, if any were supplied.
  // (Initializing the flags directly replaces the original's uninitialized
  // declarations that were assigned in both branches of an if/else.)
  const bool has_ops = (context.params.count("op") > 0);
  std::unordered_set<string> ops;
  if (has_ops) {
    for (const string& op : context.params.at("op")) {
      ops.insert(op);
    }
  }
  const bool has_prefixes = (context.params.count("prefix") > 0);
  std::unordered_set<string> prefixes;
  if (has_prefixes) {
    for (const string& prefix : context.params.at("prefix")) {
      prefixes.insert(prefix);
    }
  }
  string message;
  TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
  bool show_name;
  TF_RETURN_IF_ERROR(
      context.GetOneBoolParameter("show_name", false, &show_name));
  bool show_op;
  TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
  int32 first_n;
  TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
  int32 summarize;
  TF_RETURN_IF_ERROR(
      context.GetOneInt32Parameter("summarize", 1024, &summarize));
  // Record which output ports of each node are actually consumed somewhere in
  // the graph, so the Print node can monitor all of them.
  std::unordered_map<string, std::set<int>> node_outputs;
  for (const NodeDef& node : input_graph_def.node()) {
    for (const string& input : node.input()) {
      const string canonical_input = CanonicalInputName(input);
      string prefix;
      string name;
      string suffix;
      NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
      const string output_index_string = suffix.substr(1, suffix.size() - 1);
      int32 output_index;
      if (!strings::safe_strto32(output_index_string, &output_index)) {
        return errors::InvalidArgument("Couldn't understand output number in ",
                                       input);
      }
      node_outputs[name].insert(output_index);
    }
  }
  std::map<string, string> inputs_to_rename;
  std::unordered_set<string> ignore_when_renaming;
  GraphDef logged_graph_def;
  for (const NodeDef& node : input_graph_def.node()) {
    NodeDef* new_node = logged_graph_def.mutable_node()->Add();
    new_node->CopyFrom(node);
    if (node_outputs[node.name()].empty()) {
      // There were no outputs found to this node, so skip it.
      continue;
    }
    const bool op_matches = (ops.count(node.op()) > 0);
    bool prefix_matches = false;
    for (const string& prefix : prefixes) {
      if (StringPiece(node.name()).starts_with(prefix)) {
        prefix_matches = true;
      }
    }
    // If we're not looking for ops, or we found the right op, and if we're not
    // looking for prefixes or we found the right prefix, then add logging here.
    if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
      const string name_suffix = "__print__";
      DataTypeVector input_types;
      DataTypeVector output_types;
      TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
      NodeDef* print_node = logged_graph_def.mutable_node()->Add();
      print_node->set_op("Print");
      print_node->set_name(strings::StrCat(node.name(), name_suffix));
      // Build the message prefix: optional ";op;" and ";name;" markers, then
      // any user-supplied text.
      string node_message;
      if (show_op) {
        node_message += ";" + node.op() + ";";
      }
      if (show_name) {
        node_message += ";" + print_node->name() + ";";
      }
      node_message += message;
      SetNodeAttr("message", node_message, print_node);
      SetNodeAttr("first_n", first_n, print_node);
      SetNodeAttr("summarize", summarize, print_node);
      // The first input is the value passed through; the remaining inputs are
      // the outputs whose contents get printed.
      print_node->add_input(node.name() + ":0");
      SetNodeAttr("T", output_types[0], print_node);
      for (int output_index : node_outputs[node.name()]) {
        print_node->add_input(strings::StrCat(node.name(), ":", output_index));
      }
      SetNodeAttr("U", output_types, print_node);
      ignore_when_renaming.insert(print_node->name());
      // Rewrite the graph so all references to the first input of the original
      // op now pull from the print op instead, so it's executed.
      inputs_to_rename[node.name() + ":0"] =
          strings::StrCat(node.name(), name_suffix, ":0");
    }
  }
  output_graph_def->Clear();
  RenameNodeInputs(logged_graph_def, inputs_to_rename, ignore_when_renaming,
                   output_graph_def);
  return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,203 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status InsertLogging(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class InsertLoggingTest : public ::testing::Test {
protected:
void CheckGraphCanRun(const GraphDef& graph_def,
const std::vector<string>& output_names) {
std::unique_ptr<Session> session(NewSession(SessionOptions()));
TF_ASSERT_OK(session->Create(graph_def));
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run({}, output_names, {}, &outputs));
}
void TestInsertLogging() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor const_tensor(DT_FLOAT, TensorShape({10}));
test::FillIota<float>(&const_tensor, 1.0f);
Output const_node1 =
Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
Output const_node2 =
Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
Output const_node3 =
Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
Output add_node2 =
Add(root.WithOpName("add_node2"), const_node1, const_node2);
Output add_node3 =
Add(root.WithOpName("add_node3"), const_node1, const_node3);
Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
Output add_node4 =
Add(root.WithOpName("add_node4"), mul_node1, const_node3);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
CheckGraphCanRun(graph_def, {"add_node4"});
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"add_node4"};
TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
CheckGraphCanRun(result, {"add_node4"});
std::unordered_set<string> print_inputs;
for (const NodeDef& node : result.node()) {
if (node.op() == "Print") {
print_inputs.insert(node.input(0));
}
}
EXPECT_EQ(6, print_inputs.size());
EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
EXPECT_EQ(1, print_inputs.count("add_node2:0"));
EXPECT_EQ(1, print_inputs.count("add_node3:0"));
EXPECT_EQ(0, print_inputs.count("add_node4:0"));
EXPECT_EQ(1, print_inputs.count("const_node1:0"));
EXPECT_EQ(1, print_inputs.count("const_node2:0"));
EXPECT_EQ(1, print_inputs.count("const_node3:0"));
}
void TestInsertLoggingByOpType() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor const_tensor(DT_FLOAT, TensorShape({10}));
test::FillIota<float>(&const_tensor, 1.0f);
Output const_node1 =
Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
Output const_node2 =
Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
Output const_node3 =
Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
Output add_node2 =
Add(root.WithOpName("add_node2"), const_node1, const_node2);
Output add_node3 =
Add(root.WithOpName("add_node3"), const_node1, const_node3);
Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
Output add_node4 =
Add(root.WithOpName("add_node4"), mul_node1, const_node3);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
CheckGraphCanRun(graph_def, {"add_node4"});
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"add_node4"};
context.params.insert(
std::pair<string, std::vector<string>>({"op", {"Mul", "Add"}}));
TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
CheckGraphCanRun(result, {"add_node4"});
std::unordered_set<string> print_inputs;
for (const NodeDef& node : result.node()) {
if (node.op() == "Print") {
print_inputs.insert(node.input(0));
}
}
EXPECT_EQ(3, print_inputs.size());
EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
EXPECT_EQ(1, print_inputs.count("add_node2:0"));
EXPECT_EQ(1, print_inputs.count("add_node3:0"));
EXPECT_EQ(0, print_inputs.count("add_node4:0"));
EXPECT_EQ(0, print_inputs.count("const_node1:0"));
EXPECT_EQ(0, print_inputs.count("const_node2:0"));
EXPECT_EQ(0, print_inputs.count("const_node3:0"));
}
void TestInsertLoggingByPrefix() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor const_tensor(DT_FLOAT, TensorShape({10}));
test::FillIota<float>(&const_tensor, 1.0f);
Output const_node1 =
Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
Output const_node2 =
Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
Output const_node3 =
Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
Output add_node2 =
Add(root.WithOpName("add_node2"), const_node1, const_node2);
Output add_node3 =
Add(root.WithOpName("add_node3"), const_node1, const_node3);
Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
Output add_node4 =
Add(root.WithOpName("add_node4"), mul_node1, const_node3);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
CheckGraphCanRun(graph_def, {"add_node4"});
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"add_node4"};
context.params.insert(
std::pair<string, std::vector<string>>({"prefix", {"add_node"}}));
TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
CheckGraphCanRun(result, {"add_node4"});
std::unordered_set<string> print_inputs;
for (const NodeDef& node : result.node()) {
if (node.op() == "Print") {
print_inputs.insert(node.input(0));
}
}
EXPECT_EQ(2, print_inputs.size());
EXPECT_EQ(0, print_inputs.count("mul_node1:0"));
EXPECT_EQ(1, print_inputs.count("add_node2:0"));
EXPECT_EQ(1, print_inputs.count("add_node3:0"));
EXPECT_EQ(0, print_inputs.count("add_node4:0"));
EXPECT_EQ(0, print_inputs.count("const_node1:0"));
EXPECT_EQ(0, print_inputs.count("const_node2:0"));
EXPECT_EQ(0, print_inputs.count("const_node3:0"));
}
};
// Register the three scenarios above with gtest.
TEST_F(InsertLoggingTest, TestInsertLogging) { TestInsertLogging(); }
TEST_F(InsertLoggingTest, TestInsertLoggingByOpType) {
TestInsertLoggingByOpType();
}
TEST_F(InsertLoggingTest, TestInsertLoggingByPrefix) {
TestInsertLoggingByPrefix();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -236,7 +236,8 @@ Status MergeDuplicateNodes(const GraphDef& input_graph_def,
}
// Update the graph so that any nodes that referred to removed inputs now
// pull from the remaining duplicate.
RenameNodeInputs(merged_graph_def, inputs_to_rename, &current_graph_def);
RenameNodeInputs(merged_graph_def, inputs_to_rename,
std::unordered_set<string>(), &current_graph_def);
} while (any_duplicates_found);
*output_graph_def = current_graph_def;
@ -295,7 +296,8 @@ Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
},
{true}, &replaced_graph_def));
RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def);
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
std::unordered_set<string>(), output_graph_def);
return Status::OK();
}
@ -372,9 +374,9 @@ Status QuantizePlaceholders(const GraphDef& input_graph_def,
GraphDef first_pass_graph_def;
RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
&first_pass_graph_def);
std::unordered_set<string>(), &first_pass_graph_def);
RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
output_graph_def);
std::unordered_set<string>(), output_graph_def);
return Status::OK();
}

View File

@ -80,7 +80,7 @@ Status RemoveNodes(const GraphDef& input_graph_def,
{true}, &replaced_graph_def));
// Make sure all references to removed nodes now point to their inputs.
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
&current_graph_def);
std::unordered_set<string>(), &current_graph_def);
} while (any_nodes_removed);
}

View File

@ -33,8 +33,9 @@ namespace graph_transforms {
Status RoundWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
int64 num_steps;
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
int32 num_steps;
TF_RETURN_IF_ERROR(
context.GetOneInt32Parameter("num_steps", 256, &num_steps));
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, {"Const"},
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,

View File

@ -469,6 +469,7 @@ Status ReplaceMatchingOpTypes(
Status RenameNodeInputs(const GraphDef& input_graph_def,
const std::map<string, string>& inputs_to_rename,
const std::unordered_set<string>& nodes_to_ignore,
GraphDef* output_graph_def) {
std::map<string, std::vector<std::pair<string, string>>>
canonical_inputs_to_rename;
@ -494,6 +495,9 @@ Status RenameNodeInputs(const GraphDef& input_graph_def,
input_node_name);
}
already_visited.insert(input_node_name);
if (nodes_to_ignore.count(node.name())) {
break;
}
bool any_match_found = false;
for (const std::pair<string, string>& input_to_rename :
canonical_inputs_to_rename.at(input_node_name)) {
@ -630,9 +634,26 @@ Status TransformFuncContext::GetOneStringParameter(const string& name,
}
}
Status TransformFuncContext::GetOneIntParameter(const string& name,
int64 default_value,
int64* result) const {
// Looks up `name` among the transform's parameters and parses it as a 32-bit
// integer into `*result`. If the parameter is absent, `*result` is set to
// `default_value` and OK is returned. Returns InvalidArgument if the
// parameter occurs more than once (via GetOneStringParameter) or if the
// value can't be parsed as an int32 (safe_strto32 also rejects
// floating-point strings and out-of-range values).
Status TransformFuncContext::GetOneInt32Parameter(const string& name,
                                                  int32 default_value,
                                                  int32* result) const {
  const int params_count = CountParameters(name);
  if (params_count == 0) {
    // Absent parameters silently fall back to the caller-supplied default.
    *result = default_value;
    return Status::OK();
  }
  string string_value;
  TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
  if (!strings::safe_strto32(StringPiece(string_value), result)) {
    // Space after the colon so the offending value isn't fused to the label.
    return errors::InvalidArgument("Couldn't interpret the ", name,
                                   " argument as a number: ", string_value);
  }
  return Status::OK();
}
Status TransformFuncContext::GetOneInt64Parameter(const string& name,
int64 default_value,
int64* result) const {
const int params_count = CountParameters(name);
if (params_count == 0) {
*result = default_value;

View File

@ -17,8 +17,10 @@ limitations under the License.
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
#include <set>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@ -201,9 +203,11 @@ Status ReplaceMatchingOpTypes(
// Returns a list of the unique nodes found in this match.
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
// Changes all input references to a particular node name.
// Changes all input references to a particular node name. Any nodes with names
// listed in nodes_to_ignore will not have their inputs rewritten.
Status RenameNodeInputs(const GraphDef& input_graph_def,
const std::map<string, string>& inputs_to_rename,
const std::unordered_set<string>& nodes_to_ignore,
GraphDef* output_graph_def);
// Utility function that copies all the nodes found in a match into the
@ -225,11 +229,17 @@ struct TransformFuncContext {
Status GetOneStringParameter(const string& name, const string& default_value,
string* result) const;
// Gets a single occurrence of a parameter as an integer, falling back to a
// default if it isn't present and returning an error if it isn't convertible
// to a number.
Status GetOneIntParameter(const string& name, int64 default_value,
int64* result) const;
// Gets a single occurrence of a parameter as a 32-bit integer, falling back
// to a default if it isn't present and returning an error if it isn't
// convertible to a number.
Status GetOneInt32Parameter(const string& name, int32 default_value,
int32* result) const;
// Gets a single occurrence of a parameter as a 64-bit integer, falling back
// to a default if it isn't present and returning an error if it isn't
// convertible to a number.
Status GetOneInt64Parameter(const string& name, int64 default_value,
int64* result) const;
// Gets a single occurrence of a parameter as a floating point number, falling
// back to a default if it isn't present and returning an error if it isn't

View File

@ -541,7 +541,9 @@ class TransformUtilsTest : public ::testing::Test {
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def));
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}},
std::unordered_set<string>(),
&renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
@ -579,7 +581,7 @@ class TransformUtilsTest : public ::testing::Test {
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(
graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
&renamed_graph_def));
std::unordered_set<string>(), &renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
@ -615,8 +617,9 @@ class TransformUtilsTest : public ::testing::Test {
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
&renamed_graph_def);
Status rename_status =
RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
std::unordered_set<string>(), &renamed_graph_def);
EXPECT_FALSE(rename_status.ok());
}
@ -650,6 +653,7 @@ class TransformUtilsTest : public ::testing::Test {
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
std::unordered_set<string>(),
&renamed_graph_def));
std::map<string, const NodeDef*> node_map;
@ -658,6 +662,45 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
}
void TestRenameNodeInputsWithIgnores() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Output add = Add(root.WithOpName("add"), a_const, a_const);
Output add2 = Add(root.WithOpName("add2"), a_const, a_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("mul"), add, placeholder);
Output mul2 = Mul(root.WithOpName("output"), mul, add2);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"},
&renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
EXPECT_EQ("b", node_map.at("add")->input(0));
EXPECT_EQ("b", node_map.at("add")->input(1));
EXPECT_EQ("a", node_map.at("add2")->input(0));
EXPECT_EQ("a", node_map.at("add2")->input(1));
}
void TestFindInvalidInputs() {
GraphDef graph_def;
@ -943,20 +986,36 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ("d", value);
}
void TestGetOneIntParameter() {
void TestGetOneInt32Parameter() {
TransformFuncContext context;
context.params.insert({"foo", {"10", "20"}});
context.params.insert({"bar", {"-23"}});
context.params.insert({"not_a_number", {"not_numerical"}});
context.params.insert({"float", {"-23.232323"}});
int32 value;
TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value));
EXPECT_EQ(-23, value);
EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok());
TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value));
EXPECT_EQ(10, value);
EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok());
EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok());
}
void TestGetOneInt64Parameter() {
TransformFuncContext context;
context.params.insert({"foo", {"10", "20"}});
context.params.insert({"bar", {"-23"}});
context.params.insert({"not_a_number", {"not_numerical"}});
context.params.insert({"float", {"-23.232323"}});
int64 value;
TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value));
EXPECT_EQ(-23, value);
EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok());
TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value));
EXPECT_EQ(10, value);
EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok());
EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok());
}
void TestGetOneFloatParameter() {
@ -1111,6 +1170,10 @@ TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
TestRenameNodeInputsWithWildcard();
}
// Register the ignore-list rename scenario with gtest.
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) {
TestRenameNodeInputsWithIgnores();
}
TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
@ -1127,7 +1190,13 @@ TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
TestGetOneStringParameter();
}
TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
// Register the split int32/int64 parameter-parsing scenarios with gtest.
TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) {
TestGetOneInt32Parameter();
}
TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) {
TestGetOneInt64Parameter();
}
TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
TestGetOneFloatParameter();