diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index e5807840cef..67d603dd0ae 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 3bd2cc9ac92..8bd18bec099 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -60,7 +60,9 @@ cc_library(
         "fold_batch_norms.cc",
         "fold_constants_lib.cc",
         "fold_old_batch_norms.cc",
+        "freeze_requantization_ranges.cc",
         "fuse_convolutions.cc",
+        "insert_logging.cc",
         "obsfucate_names.cc",
         "quantize_nodes.cc",
         "quantize_weights.cc",
@@ -101,7 +103,9 @@ tf_cc_test(
        "fold_batch_norms_test.cc",
        "fold_constants_test.cc",
        "fold_old_batch_norms_test.cc",
+       "freeze_requantization_ranges_test.cc",
        "fuse_convolutions_test.cc",
+       "insert_logging_test.cc",
        "obsfucate_names_test.cc",
        "quantize_nodes_test.cc",
        "quantize_weights_test.cc",
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 334c65117ee..37545feb4f0 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -16,7 +16,9 @@
 * [fold_batch_norms](#fold_batch_norms)
 * [fold_constants](#fold_constants)
 * [fold_old_batch_norms](#fold_old_batch_norms)
+* [freeze_requantization_ranges](#freeze_requantization_ranges)
 * [fuse_convolutions](#fuse_convolutions)
+* [insert_logging](#insert_logging)
 * [merge_duplicate_nodes](#merge_duplicate_nodes)
 * [obsfucate_names](#obsfucate_names)
 * [quantize_nodes](#quantize_nodes)
@@ -334,7 +336,7 @@ within the saved model, and sets them to the defined default for that attribute.
 
 ### fold_batch_norms
 
-Args: None
+Args: None \
 Prerequisites: [fold_constants](#fold_constants)
 
 This transform tries to optimize away the Mul that's introduced after a Conv2D
@@ -347,7 +349,7 @@ produced by training for the Mul input is collapsed down into a simple
 constant.
 
 ### fold_constants
 
-Args: None\
+Args: None \
 Prerequisites: None
 
 Looks for any sub-graphs within the model that always evaluate to constant
@@ -359,7 +361,7 @@ to continue on past transient errors, since this is just an optimization phase.
 
 ### fold_old_batch_norms
 
-Args: None\
+Args: None \
 Prerequisites: None
 
 In the early days of TensorFlow, batch normalization was implemented using a
@@ -370,21 +372,143 @@ have a graph that uses the older-style, this transform will recognize and
 optimize those ops for inference, in the same way that the
 [fold_batch_norms](#fold_batch_norms) transform does for the new approach.
 
+### freeze_requantization_ranges
+
+Args:
+
+* min_max_log_file: Path to a log file containing ranges for ops.
+* min_percentile: Percentage cutoff to use to calculate an overall min.
+  Defaults to 5.
+* max_percentile: Percentage cutoff to use to calculate an overall max.
+  Defaults to 5.
+
+Quantized operations like convolution or matrix multiplies take their inputs as
+8-bit, but produce 32-bit results.
+To do further operations on these, they need to be converted back down to the
+lower depth. To make the most of those eight bits, you need to scale the
+thirty-two bits of original data down using a scale that matches the range
+that's actually being used.
+
+Because that range information isn't stored in the original graph, the
+[quantization process](#eight-bit-calculations) inserts RequantizationRange ops
+before each conversion from 32 to 8 bits. This op looks at the 32-bit output and
+calculates the current min and max every time it's run.
+
+This isn't incredibly time-consuming, but it is extra work that's nice to avoid
+if possible. One way of optimizing it away is to replace those
+RequantizationRange ops with a pair of Const nodes holding known min/max values,
+so the scaling down can be done without having to inspect the output every time.
+
+That's what this transform does. It's usually used in conjunction with a copy of
+the graph that's had [insert_logging](#insert_logging) run on it to instrument
+it to record the min/max values to stderr. Why is logging used rather than
+writing to a normal file? As you'll see later, to get the best results you want
+to collect data from a lot of runs on real data, and for mobile apps especially
+it's a lot easier to do this by copying log files. As an example, here's how
+you'd add the logging operations for a quantized version of the Inception v3
+graph:
+
+```bash
+bazel build tensorflow/tools/graph_transforms:transform_graph
+bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
+--logtostderr \
+--in_graph=/tmp/quantized_inception.pb \
+--out_graph=/tmp/logged_quantized_inception.pb \
+--inputs=Mul \
+--outputs=softmax \
+--transforms='\
+insert_logging(op=RequantizationRange, show_name=true, message="__requant_min_max:")\
+'
+```
+
+Now, when you run the `/tmp/logged_quantized_inception.pb` graph, it will write
+out log statements that show the value of the min and max calculated by each
+RequantizationRange op. Here's an example of running label_image and saving the
+log:
+
+```bash
+bazel build tensorflow/examples/label_image:label_image
+bazel-bin/tensorflow/examples/label_image/label_image \
+--image=${HOME}/Downloads/grace_hopper.jpg \
+--input_layer=Mul \
+--output_layer=softmax \
+--graph=/tmp/logged_quantized_inception.pb \
+--labels=learning/brain/models/image/inception_v3/imagenet_comp_graph_label_strings.txt \
+--logtostderr \
+2>/tmp/min_max_log_small.txt
+```
+
+If you look in `/tmp/min_max_log_small.txt`, you'll see a lot of lines like
+this:
+
+```
+I0108 21:45:42.261883 1972 logging_ops.cc:79] ;conv/Conv2D/eightbit/requant_range__print__;__requant_min_max:[-20.887871][22.274715]
+```
+
+This is a simple way of serializing the name of the RequantizationRange op and
+its min/max values every time it's run. It's a file like this that you pass into
+the transform as the `min_max_log_file` argument. The transform will attempt to
+extract all of the min/max values associated with ops, ignoring any irrelevant
+lines in the log, and replace the RequantizationRange ops with two Const nodes
+containing the found values.
+
+This isn't the whole story though. The min/max values can vary a lot depending
+on what the particular inputs to the graph are on any given run, which means
+picking ranges based on just one run can lead to clipping of values and a loss
+of accuracy. To get better results, you need to run your network against a
+range of different inputs.
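+Once you have a log file, the freezing step itself is just another
+transform_graph invocation. As a sketch (the output path here is illustrative,
+and in practice you'd use a log aggregated from many runs, as discussed next):
+
+```bash
+bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
+--in_graph=/tmp/quantized_inception.pb \
+--out_graph=/tmp/frozen_requant_inception.pb \
+--inputs=Mul \
+--outputs=softmax \
+--transforms='\
+freeze_requantization_ranges(min_max_log_file=/tmp/min_max_log_small.txt)\
+'
+```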
+In Inception's case, I often use a thousand different images from the training
+set to generate that log. You can then pass the whole concatenated log from all
+of the runs into the transform, and it will pick ranges based on the aggregate
+of the values found for each RequantizationRange op.
+
+To ensure that outliers don't increase the range too much, and so decrease the
+accuracy by putting too many bits into rare extreme values, the `min_percentile`
+and `max_percentile` arguments control how the overall min and max are chosen.
+At their default values of 5, the lowest 5% of the observed minimum values are
+discarded and the minimum of the remainder is used, with the equivalent applied
+to the maximum values.
+
 ### fuse_convolutions
 
-Args: None\
+Args: None \
 Prerequisites: None
 
-For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g.
-to scale up in the later stages of an image style transfer model),
-it can improve memory usage and latency to combine the spatial
-transformations with the convolution's im2col patch generation. This transform
-looks out for that particular pattern of ops and replaces them with a fused
-version that combines the resizing and padding with the convolution.
+For graphs that use ResizeBilinear or MirrorPad ops before convolutions (e.g. to
+scale up in the later stages of an image style transfer model), it can improve
+memory usage and latency to combine the spatial transformations with the
+convolution's im2col patch generation. This transform looks out for that
+particular pattern of ops and replaces them with a fused version that combines
+the resizing and padding with the convolution.
+
+### insert_logging
+
+Args:
+
+* op: Insert a Print after every occurrence of this op type. Can be repeated
+  to cover multiple types. If not present, all op types will be instrumented.
+* prefix: Insert a Print after every node whose name starts with this value.
+  Can be repeated to cover multiple nodes. If not present, all node names will
+  be matched.
+* show_op: If true, the op type will be prepended to all log messages.
+* show_name: If true, the node's name will be prepended to all log messages.
+* message: Arbitrary text to log before the values.
+* first_n: How many times to print before suppressing. Defaults to -1, which
+  means never stop.
+* summarize: How long numerical results can be before they're truncated.
+  Defaults to 1024.
+
+The Print operator writes strings to stderr when it's run inside a graph, and
+prints out the numerical results of the node that it's reading from. This can be
+very useful when you're debugging and want to follow particular internal values
+while a graph is running. This transform allows you to insert those ops at
+particular points in the graph, and customize the message that's displayed. It's
+also used in conjunction with the
+[freeze_requantization_ranges](#freeze_requantization_ranges) transform to
+output the information that it needs.
 
 ### merge_duplicate_nodes
 
-Args: None\
+Args: None \
 Prerequisites: None
 
 If there are Const nodes with the same types and contents, or nodes with the
@@ -396,7 +520,7 @@ duplicates of constants that are used in the quantize/dequantize process).
 ### obsfucate_names
 
-Args: None\
+Args: None \
 Prerequisites: None
 
 Replaces all nodes' names with short generated ids, other than the inputs and
@@ -429,14 +553,14 @@ Prerequisites: [quantize_weights](#quantize_weights)
 
 Replaces any calculation nodes with their eight-bit equivalents (if available),
 and adds in conversion layers to allow remaining float operations to
 interoperate. This is one of the most complex transforms, and involves multiple
-passes and a lot of rewriting. It's also still an active area of research,
-so results may vary depending on the platform and operations you're using in
-your model. You should run quantize_weights first to ensure your Const ops are
-in eight-bit form.
+passes and a lot of rewriting. It's also still an active area of research, so
+results may vary depending on the platform and operations you're using in your
+model. You should run quantize_weights first to ensure your Const ops are in
+eight-bit form.
 
 ### quantize_weights
 
-Args: None\
+Args: None \
 Prerequisites: None
 
 Converts any large (more than 15 element) float Const op into an eight-bit
@@ -461,7 +585,7 @@ special circumstances though.
 
 ### remove_device
 
-Args: None
+Args: None \
 Prerequisites: None
 
 All ops can have a hardware device specified. This can be a problem when you're
@@ -548,7 +672,7 @@ device assigned.
 
 ### sort_by_execution_order
 
-Args: None\
+Args: None \
 Prerequisites: None
 
 Arranges the nodes in the GraphDef in topological order, so that the inputs of
@@ -741,8 +865,8 @@ transform:
 
 This is looking for QuantizeV2 nodes, with three inputs, the first of which is a
 Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the
 third is a Max which does the same. Assuming we know the Dequantize ops are
-pulling from the same eight-bit buffer, the end result of this sub-graph is
-a no-op, since it's just turning the eight-bit buffer into float, and then
+pulling from the same eight-bit buffer, the end result of this sub-graph is a
+no-op, since it's just turning the eight-bit buffer into float, and then
 immediately converting it back to eight-bits, so if we look for this pattern
 and remove it we can optimize the graph without changing the result.
 
@@ -874,7 +998,7 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
 parameter:
 
 ```C++
-TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
+TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("num_steps", 256, &num_steps));
 ```
 
 If the conversion fails or the parameter occurs more than once the helper
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
new file mode 100644
index 00000000000..8faad4a442d
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
@@ -0,0 +1,213 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+struct MinMaxRecord {
+  string name;
+  float min;
+  float max;
+};
+
+// Try to parse a log file containing loosely-structured lines, some of which
+// are the min/max logs we want.
+Status ExtractMinMaxRecords(const string& log_file_name,
+                            std::vector<MinMaxRecord>* records) {
+  string file_data;
+  TF_RETURN_IF_ERROR(
+      ReadFileToString(Env::Default(), log_file_name, &file_data));
+  const string print_suffix("__print__");
+  const string requant_prefix("__requant_min_max:");
+  std::vector<string> file_lines = str_util::Split(file_data, '\n');
+  for (const string& file_line : file_lines) {
+    // We expect to find a line with components separated by semicolons, so to
+    // start make sure that the basic structure is in place.
+    StringPiece line(file_line);
+    if (!line.contains(print_suffix + ";" + requant_prefix)) {
+      continue;
+    }
+    std::vector<string> line_parts = str_util::Split(file_line, ';');
+    if (line_parts.size() < 2) {
+      continue;
+    }
+    // Now we want to figure out which components have the name and min max
+    // values by scanning for the prefix we expect.
+    bool min_max_found = false;
+    int min_max_index;
+    for (int i = 1; i < line_parts.size(); ++i) {
+      StringPiece line_part(line_parts[i]);
+      if (line_part.starts_with(requant_prefix)) {
+        min_max_found = true;
+        min_max_index = i;
+      }
+    }
+    if (!min_max_found) {
+      continue;
+    }
+    // Finally we need to break out the values from the strings, and parse them
+    // into a form we can use.
+    string min_max_string = line_parts[min_max_index];
+    std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
+    if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
+      continue;
+    }
+    string min_string = min_max_parts[1];
+    std::vector<string> min_string_parts = str_util::Split(min_string, ']');
+    if (min_string_parts.size() != 2) {
+      continue;
+    }
+    string min_number_string = min_string_parts[0];
+    float min;
+    if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
+      continue;
+    }
+    string max_string = min_max_parts[2];
+    std::vector<string> max_string_parts = str_util::Split(max_string, ']');
+    if (max_string_parts.size() != 2) {
+      continue;
+    }
+    string max_number_string = max_string_parts[0];
+    float max;
+    if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
+      continue;
+    }
+    StringPiece name_string = line_parts[min_max_index - 1];
+    if (!name_string.ends_with(print_suffix)) {
+      continue;
+    }
+    string name =
+        name_string.substr(0, name_string.size() - print_suffix.size())
+            .ToString();
+    records->push_back({name, min, max});
+  }
+  return Status::OK();
+}
+
+// Uses the observed min/max values for requantization captured in a log file
+// to replace costly RequantizationRange ops with simple Consts.
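+// The log file is the stderr output captured from running a graph that was
+// instrumented with the insert_logging transform, as described in the README.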
+Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
+                                  const TransformFuncContext& context,
+                                  GraphDef* output_graph_def) {
+  string min_max_log_file;
+  TF_RETURN_IF_ERROR(
+      context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
+  if (min_max_log_file == "") {
+    return errors::InvalidArgument(
+        "You must pass a file name to min_max_log_file");
+  }
+  float min_percentile;
+  TF_RETURN_IF_ERROR(
+      context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
+  float max_percentile;
+  TF_RETURN_IF_ERROR(
+      context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
+
+  std::vector<MinMaxRecord> records;
+  TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
+  if (records.empty()) {
+    return errors::InvalidArgument(
+        "No min/max range logs were found in the log file");
+  }
+
+  std::map<string, const NodeDef*> node_map;
+  MapNamesToNodes(input_graph_def, &node_map);
+  bool any_missing_nodes = false;
+  std::map<string, std::vector<MinMaxRecord>> records_by_node;
+  for (const MinMaxRecord& record : records) {
+    records_by_node[record.name].push_back(record);
+    if (!node_map.count(record.name)) {
+      any_missing_nodes = true;
+      LOG(WARNING) << "Node from log not found in graph: " << record.name;
+    }
+  }
+  if (any_missing_nodes) {
+    return errors::InvalidArgument(
+        "Nodes were found in the log file that aren't present in the graph");
+  }
+
+  // Now find out the largest and smallest min/max values for the node.
+  std::map<string, std::pair<float, float>> range_for_nodes;
+  for (const auto& record_info : records_by_node) {
+    const string& name = record_info.first;
+    const std::vector<MinMaxRecord> records = record_info.second;
+    std::vector<float> mins;
+    std::vector<float> maxs;
+    for (const MinMaxRecord& record : records) {
+      mins.push_back(record.min);
+      maxs.push_back(record.max);
+    }
+    std::sort(mins.begin(), mins.end());
+    std::sort(maxs.begin(), maxs.end());
+    int min_index = std::round(mins.size() * (min_percentile / 100.0f));
+    if (min_index < 0) {
+      min_index = 0;
+    }
+    int max_index =
+        std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
+    if (max_index > (maxs.size() - 1)) {
+      max_index = maxs.size() - 1;
+    }
+    const float min = mins[min_index];
+    const float max = maxs[max_index];
+    range_for_nodes[name] = {min, max};
+  }
+  std::map<string, string> inputs_to_rename;
+  GraphDef frozen_graph_def;
+  for (const NodeDef& node : input_graph_def.node()) {
+    if (range_for_nodes.count(node.name())) {
+      if (node.op() != "RequantizationRange") {
+        return errors::InvalidArgument(
+            "Node is expected to be a RequantizationRange op: ", node.name(),
+            ", but is: ", node.op());
+      }
+      const float min_value = range_for_nodes.at(node.name()).first;
+      NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
+      min_node->set_op("Const");
+      min_node->set_name(node.name() + "/frozen_min");
+      SetNodeAttr("dtype", DT_FLOAT, min_node);
+      Tensor min_tensor(DT_FLOAT, {});
+      min_tensor.flat<float>()(0) = min_value;
+      SetNodeTensorAttr<float>("value", min_tensor, min_node);
+      inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
+
+      const float max_value = range_for_nodes.at(node.name()).second;
+      NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
+      max_node->set_op("Const");
+      max_node->set_name(node.name() + "/frozen_max");
+      SetNodeAttr("dtype", DT_FLOAT, max_node);
+      Tensor max_tensor(DT_FLOAT, {});
+      max_tensor.flat<float>()(0) = max_value;
+      SetNodeTensorAttr<float>("value", max_tensor, max_node);
+      inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
+    } else {
+      NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
+      new_node->CopyFrom(node);
+    }
+  }
+
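+  // Rewire everything that read the RequantizationRange op's min and max
+  // outputs to pull from the new Const nodes instead.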
+  RenameNodeInputs(frozen_graph_def, inputs_to_rename,
+                   std::unordered_set<string>(), output_graph_def);
+  return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
+                         FreezeRequantizationRanges);
+
+}  // namespace graph_transforms
+}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc
new file mode 100644
index 00000000000..ab6b2ffef0f
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc
@@ -0,0 +1,200 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
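+// These declarations must be kept in sync with the definitions in
+// freeze_requantization_ranges.cc.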
+Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
+                                  const TransformFuncContext& context,
+                                  GraphDef* output_graph_def);
+struct MinMaxRecord {
+  string name;
+  float min;
+  float max;
+};
+Status ExtractMinMaxRecords(const string& log_file_name,
+                            std::vector<MinMaxRecord>* records);
+
+class FreezeRequantizationRangesTest : public ::testing::Test {
+ protected:
+  void TestFreezeRequantizationRanges() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6}));
+    test::FillValues<quint8>(&quantized_tensor, {0, 0, 0, 0, 0, 0});
+    Output quantized_op = Const(root.WithOpName("quantized_op"),
+                                Input::Initializer(quantized_tensor));
+
+    Tensor quantized_min_tensor(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&quantized_min_tensor, {2.0f});
+    Output quantized_min_op = Const(root.WithOpName("quantized_min_op"),
+                                    Input::Initializer(quantized_min_tensor));
+
+    Tensor quantized_max_tensor(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&quantized_max_tensor, {2.0f});
+    Output quantized_max_op = Const(root.WithOpName("quantized_max_op"),
+                                    Input::Initializer(quantized_max_tensor));
+
+    Tensor offset_tensor(DT_QUINT8, TensorShape({6}));
+    test::FillValues<quint8>(&offset_tensor, {1, 2, 3, 4, 5, 6});
+    Output offset_op =
+        Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor));
+
+    Tensor offset_min_tensor(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&offset_min_tensor, {0.0f});
+    Output offset_min_op = Const(root.WithOpName("offset_min_op"),
+                                 Input::Initializer(offset_min_tensor));
+
+    Tensor offset_max_tensor(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&offset_max_tensor, {255.0f});
+    Output offset_max_op = Const(root.WithOpName("offset_max_op"),
+                                 Input::Initializer(offset_max_tensor));
+
+    QuantizedBiasAdd quantized_bias_add_op(
+        root.WithOpName("bias_add_op"), quantized_op, offset_op,
+        quantized_min_op, quantized_max_op, offset_min_op, offset_max_op,
+        DT_QINT32);
+
+    RequantizationRange requantization_range_op(
+        root.WithOpName("requantization_range_op"),
+        quantized_bias_add_op.output, quantized_bias_add_op.min_out,
+        quantized_bias_add_op.max_out);
+
+    Requantize requantize_op(
+        root.WithOpName("requantize_op"), quantized_bias_add_op.output,
+        quantized_bias_add_op.min_out, quantized_bias_add_op.max_out,
+        requantization_range_op.output_min, requantization_range_op.output_max,
+        DT_QUINT8);
+
+    Output dequantize_op =
+        Dequantize(root.WithOpName("dequantize_op"), requantize_op.output,
+                   requantize_op.output_min, requantize_op.output_max);
+
+    GraphDef graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+    const string min_max_log_file_name =
+        io::JoinPath(testing::TmpDir(), "min_max_log_file.txt");
+    {
+      std::unique_ptr<WritableFile> file;
+      TF_ASSERT_OK(
+          Env::Default()->NewWritableFile(min_max_log_file_name, &file));
+      TF_ASSERT_OK(file->Append("Something irrelevant\n"));
+      TF_ASSERT_OK(
+          file->Append("[SomePrefix] "
+                       ";requantization_range_op__print__;__requant_min_max:"
+                       "[-2.4313571][10.584145]\n"));
+      TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
+    }
+
+    TransformFuncContext context;
+    context.input_names = {};
+    context.output_names = {"dequantize_op"};
+    context.params = {{"min_max_log_file", {min_max_log_file_name}}};
+
+    GraphDef frozen_graph_def;
+    TF_EXPECT_OK(
+        FreezeRequantizationRanges(graph_def, context, &frozen_graph_def));
+
+    std::map<string, const NodeDef*> node_map;
+    MapNamesToNodes(frozen_graph_def, &node_map);
+    EXPECT_EQ(0, node_map.count("requantization_range_op"));
+    EXPECT_EQ(1, node_map.count("requantize_op"));
+    const string& min_input =
+        NodeNameFromInput(node_map.at("requantize_op")->input(3));
+    ASSERT_EQ(1, node_map.count(min_input));
+    EXPECT_EQ("Const", node_map.at(min_input)->op());
+    const string& max_input =
+        NodeNameFromInput(node_map.at("requantize_op")->input(4));
+    ASSERT_EQ(1, node_map.count(max_input));
+    EXPECT_EQ("Const", node_map.at(max_input)->op());
+
+    std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(original_session->Create(graph_def));
+    std::vector<Tensor> original_outputs;
+    TF_ASSERT_OK(
+        original_session->Run({}, {"dequantize_op"}, {}, &original_outputs));
+
+    std::unique_ptr<Session> frozen_session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
+    std::vector<Tensor> frozen_outputs;
+    TF_ASSERT_OK(
+        frozen_session->Run({}, {"dequantize_op"}, {}, &frozen_outputs));
+
+    ASSERT_EQ(original_outputs.size(), frozen_outputs.size());
+    ASSERT_EQ(1, frozen_outputs.size());
+    test::ExpectTensorNear<float>(original_outputs[0], frozen_outputs[0], 0.5);
+  }
+
+  void TestExtractMinMaxRecords() {
+    const string min_max_log_file_name =
+        io::JoinPath(testing::TmpDir(), "min_max_log_file2.txt");
+    {
+      std::unique_ptr<WritableFile> file;
+      TF_ASSERT_OK(
+          Env::Default()->NewWritableFile(min_max_log_file_name, &file));
+      TF_ASSERT_OK(file->Append("Something irrelevant\n"));
+      TF_ASSERT_OK(
+          file->Append("[SomePrefix] "
+                       ";requantization_range_op__print__;__requant_min_max:"
+                       "[-2.4313571][10.584145]\n"));
+      TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
+      TF_ASSERT_OK(file->Append(
+          "[SomeOtherPrefix] "
+          ";other_requantization_range_op__print__;__requant_min_max:"
+          "[-1.0][2.0]\n"));
+      TF_ASSERT_OK(file->Append("Something else irrelevant\n"));
+      TF_ASSERT_OK(
+          file->Append("[SomePrefix] "
+                       ";requantization_range_op__print__;__requant_min_max:"
+                       "[-1.bad][2.0]\n"));
+    }
+    std::vector<MinMaxRecord> records;
+    TF_ASSERT_OK(ExtractMinMaxRecords(min_max_log_file_name, &records));
+    ASSERT_EQ(2, records.size());
+    EXPECT_EQ("requantization_range_op", records[0].name);
+    EXPECT_NEAR(-2.4313571f, records[0].min, 1e-5f);
+    EXPECT_NEAR(10.584145f, records[0].max, 1e-5f);
+    EXPECT_EQ("other_requantization_range_op", records[1].name);
+    EXPECT_NEAR(-1.0f, records[1].min, 1e-5f);
+    EXPECT_NEAR(2.0f, records[1].max, 1e-5f);
+  }
+};
+
+TEST_F(FreezeRequantizationRangesTest, TestFreezeRequantizationRanges) {
+  TestFreezeRequantizationRanges();
+}
+
+TEST_F(FreezeRequantizationRangesTest, TestExtractMinMaxRecords) {
+  TestExtractMinMaxRecords();
+}
+
+}  // namespace graph_transforms
+}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/insert_logging.cc b/tensorflow/tools/graph_transforms/insert_logging.cc
new file mode 100644
index 00000000000..c9d72f3d7d5
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/insert_logging.cc
@@ -0,0 +1,153 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
+
+#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Inserts a Print op after every node that matches the requested op types or
+// name prefixes, so the values flowing through the graph get logged to stderr.
+Status InsertLogging(const GraphDef& input_graph_def,
+                     const TransformFuncContext& context,
+                     GraphDef* output_graph_def) {
+  std::unordered_set<string> ops;
+  bool has_ops;
+  if (context.params.count("op")) {
+    has_ops = true;
+    for (const string& op : context.params.at("op")) {
+      ops.insert(op);
+    }
+  } else {
+    has_ops = false;
+  }
+
+  std::unordered_set<string> prefixes;
+  bool has_prefixes;
+  if (context.params.count("prefix")) {
+    has_prefixes = true;
+    for (const string& prefix : context.params.at("prefix")) {
+      prefixes.insert(prefix);
+    }
+  } else {
+    has_prefixes = false;
+  }
+
+  string message;
+  TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
+
+  bool show_name;
+  TF_RETURN_IF_ERROR(
+      context.GetOneBoolParameter("show_name", false, &show_name));
+
+  bool show_op;
+  TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
+
+  int32 first_n;
+  TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
+
+  int32 summarize;
+  TF_RETURN_IF_ERROR(
+      context.GetOneInt32Parameter("summarize", 1024, &summarize));
+
+  std::unordered_map<string, std::set<int32>> node_outputs;
+  for (const NodeDef& node : input_graph_def.node()) {
+    for (const string& input : node.input()) {
+      const string canonical_input = CanonicalInputName(input);
+      string prefix;
+      string name;
+      string suffix;
+      NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
+      const string output_index_string = suffix.substr(1, suffix.size() - 1);
+      int32 output_index;
+      if (!strings::safe_strto32(output_index_string, &output_index)) {
+        return errors::InvalidArgument("Couldn't understand output number in ",
+                                       input);
+      }
+      node_outputs[name].insert(output_index);
+    }
+  }
+
+  std::map<string, string> inputs_to_rename;
+  std::unordered_set<string> ignore_when_renaming;
+  GraphDef logged_graph_def;
+  for (const NodeDef& node : input_graph_def.node()) {
+    NodeDef* new_node = logged_graph_def.mutable_node()->Add();
+    new_node->CopyFrom(node);
+    if (node_outputs[node.name()].empty()) {
+      // No other node reads from this one, so there's nothing to log; skip it.
+      continue;
+    }
+    const bool op_matches = (ops.count(node.op()) > 0);
+    bool prefix_matches = false;
+    for (const string& prefix : prefixes) {
+      if (StringPiece(node.name()).starts_with(prefix)) {
+        prefix_matches = true;
+      }
+    }
+    // If we're not looking for ops, or we found the right op, and if we're not
+    // looking for prefixes or we found the right prefix, then add logging
+    // here.
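+    // The Print op forwards its first input unchanged and only logs the
+    // values of its remaining inputs, so inserting it doesn't alter the
+    // numbers flowing through the graph.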
+    if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
+      const string name_suffix = "__print__";
+      DataTypeVector input_types;
+      DataTypeVector output_types;
+      TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
+      NodeDef* print_node = logged_graph_def.mutable_node()->Add();
+      print_node->set_op("Print");
+      print_node->set_name(strings::StrCat(node.name(), name_suffix));
+      string node_message;
+      if (show_op) {
+        node_message += ";" + node.op() + ";";
+      }
+      if (show_name) {
+        node_message += ";" + print_node->name() + ";";
+      }
+      node_message += message;
+      SetNodeAttr("message", node_message, print_node);
+      SetNodeAttr("first_n", first_n, print_node);
+      SetNodeAttr("summarize", summarize, print_node);
+      print_node->add_input(node.name() + ":0");
+      SetNodeAttr("T", output_types[0], print_node);
+      for (int output_index : node_outputs[node.name()]) {
+        print_node->add_input(strings::StrCat(node.name(), ":", output_index));
+      }
+      SetNodeAttr("U", output_types, print_node);
+      ignore_when_renaming.insert(print_node->name());
+      // Rewrite the graph so all references to the first input of the
+      // original op now pull from the print op instead, so it's executed.
+      inputs_to_rename[node.name() + ":0"] =
+          strings::StrCat(node.name(), name_suffix, ":0");
+    }
+  }
+
+  output_graph_def->Clear();
+  RenameNodeInputs(logged_graph_def, inputs_to_rename, ignore_when_renaming,
+                   output_graph_def);
+
+  return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
+
+}  // namespace graph_transforms
+}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/insert_logging_test.cc b/tensorflow/tools/graph_transforms/insert_logging_test.cc
new file mode 100644
index 00000000000..e1586a46e54
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/insert_logging_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status InsertLogging(const GraphDef& input_graph_def,
+                     const TransformFuncContext& context,
+                     GraphDef* output_graph_def);
+
+class InsertLoggingTest : public ::testing::Test {
+ protected:
+  void CheckGraphCanRun(const GraphDef& graph_def,
+                        const std::vector<string>& output_names) {
+    std::unique_ptr<Session> session(NewSession(SessionOptions()));
+    TF_ASSERT_OK(session->Create(graph_def));
+    std::vector<Tensor> outputs;
+    TF_ASSERT_OK(session->Run({}, output_names, {}, &outputs));
+  }
+
+  void TestInsertLogging() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+    Tensor const_tensor(DT_FLOAT, TensorShape({10}));
+    test::FillIota<float>(&const_tensor, 1.0f);
+    Output const_node1 =
+        Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
+    Output const_node2 =
+        Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
+    Output const_node3 =
+        Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
+    Output add_node2 =
+        Add(root.WithOpName("add_node2"), const_node1, const_node2);
+    Output add_node3 =
+        Add(root.WithOpName("add_node3"), const_node1, const_node3);
+    Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
+    Output add_node4 =
+        Add(root.WithOpName("add_node4"), mul_node1, const_node3);
+    GraphDef graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+    CheckGraphCanRun(graph_def, {"add_node4"});
+
+    GraphDef result;
+    TransformFuncContext context;
+    context.input_names = {};
+    context.output_names = {"add_node4"};
+    TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
+
+    CheckGraphCanRun(result, {"add_node4"});
+
+    std::unordered_set<string> print_inputs;
+    for (const NodeDef& node : result.node()) {
+      if (node.op() == "Print") {
+        print_inputs.insert(node.input(0));
+      }
+    }
+
+    EXPECT_EQ(6, print_inputs.size());
+    EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
+    EXPECT_EQ(1, print_inputs.count("add_node2:0"));
+    EXPECT_EQ(1, print_inputs.count("add_node3:0"));
+    EXPECT_EQ(0, print_inputs.count("add_node4:0"));
+    EXPECT_EQ(1, print_inputs.count("const_node1:0"));
+    EXPECT_EQ(1, print_inputs.count("const_node2:0"));
+    EXPECT_EQ(1, print_inputs.count("const_node3:0"));
+  }
+
+  void TestInsertLoggingByOpType() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+    Tensor const_tensor(DT_FLOAT, TensorShape({10}));
+    test::FillIota<float>(&const_tensor, 1.0f);
+    Output const_node1 =
+        Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
+    Output const_node2 =
+        Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
+    Output const_node3 =
+        Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
+    Output add_node2 =
+        Add(root.WithOpName("add_node2"), const_node1, const_node2);
+    Output add_node3 =
+        Add(root.WithOpName("add_node3"), const_node1, const_node3);
+    Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
+    Output add_node4 =
+        Add(root.WithOpName("add_node4"), mul_node1, const_node3);
+    GraphDef graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+    CheckGraphCanRun(graph_def, {"add_node4"});
+
+    GraphDef result;
+    TransformFuncContext context;
+    context.input_names = {};
+    context.output_names = {"add_node4"};
+    context.params.insert(
+        std::pair<string, std::vector<string>>({"op", {"Mul", "Add"}}));
+    TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
+
+    CheckGraphCanRun(result, {"add_node4"});
+
+    std::unordered_set<string> print_inputs;
+    for (const NodeDef& node : result.node()) {
+      if (node.op() == "Print") {
+        print_inputs.insert(node.input(0));
+      }
+    }
+
+    EXPECT_EQ(3, print_inputs.size());
+    EXPECT_EQ(1, print_inputs.count("mul_node1:0"));
+    EXPECT_EQ(1, print_inputs.count("add_node2:0"));
+    EXPECT_EQ(1, print_inputs.count("add_node3:0"));
+    EXPECT_EQ(0, print_inputs.count("add_node4:0"));
+    EXPECT_EQ(0, print_inputs.count("const_node1:0"));
+    EXPECT_EQ(0, print_inputs.count("const_node2:0"));
+    EXPECT_EQ(0, print_inputs.count("const_node3:0"));
+  }
+
+  void TestInsertLoggingByPrefix() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+    Tensor const_tensor(DT_FLOAT, TensorShape({10}));
+    test::FillIota<float>(&const_tensor, 1.0f);
+    Output const_node1 =
+        Const(root.WithOpName("const_node1"), Input::Initializer(const_tensor));
+    Output const_node2 =
+        Const(root.WithOpName("const_node2"), Input::Initializer(const_tensor));
+    Output const_node3 =
+        Const(root.WithOpName("const_node3"), Input::Initializer(const_tensor));
+    Output add_node2 =
+        Add(root.WithOpName("add_node2"), const_node1, const_node2);
+    Output add_node3 =
+        Add(root.WithOpName("add_node3"), const_node1, const_node3);
+    Output mul_node1 = Mul(root.WithOpName("mul_node1"), add_node2, add_node3);
+    Output add_node4 =
+        Add(root.WithOpName("add_node4"), mul_node1, const_node3);
+    GraphDef graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+    CheckGraphCanRun(graph_def, {"add_node4"});
+
+    GraphDef result;
+    TransformFuncContext context;
+    context.input_names = {};
+    context.output_names = {"add_node4"};
+    context.params.insert(
+        std::pair<string, std::vector<string>>({"prefix", {"add_node"}}));
+    TF_ASSERT_OK(InsertLogging(graph_def, context, &result));
+
+    CheckGraphCanRun(result, {"add_node4"});
+
+    std::unordered_set<string> print_inputs;
+    for (const NodeDef& node : result.node()) {
+      if (node.op() == "Print") {
+        print_inputs.insert(node.input(0));
+      }
+    }
+
+    EXPECT_EQ(2, print_inputs.size());
+    EXPECT_EQ(0, print_inputs.count("mul_node1:0"));
+    EXPECT_EQ(1, print_inputs.count("add_node2:0"));
+    EXPECT_EQ(1, print_inputs.count("add_node3:0"));
+    EXPECT_EQ(0, print_inputs.count("add_node4:0"));
+    EXPECT_EQ(0, print_inputs.count("const_node1:0"));
+    EXPECT_EQ(0, print_inputs.count("const_node2:0"));
+    EXPECT_EQ(0, print_inputs.count("const_node3:0"));
+  }
+};
+
+TEST_F(InsertLoggingTest, TestInsertLogging) { TestInsertLogging(); }
+
+TEST_F(InsertLoggingTest, TestInsertLoggingByOpType) {
+  TestInsertLoggingByOpType();
+}
+
+TEST_F(InsertLoggingTest, TestInsertLoggingByPrefix) {
+  TestInsertLoggingByPrefix();
+}
+
+}  // namespace graph_transforms
+}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc
index fa089b86efa..f460f31d357 100644
--- a/tensorflow/tools/graph_transforms/quantize_nodes.cc
+++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc
@@ -236,7 +236,8 @@ Status MergeDuplicateNodes(const GraphDef& input_graph_def,
     }
     // Update the graph so that any nodes that referred to removed inputs now
     // pull from the remaining duplicate.
-    RenameNodeInputs(merged_graph_def, inputs_to_rename, &current_graph_def);
+    RenameNodeInputs(merged_graph_def, inputs_to_rename,
+                     std::unordered_set<string>(), &current_graph_def);
   } while (any_duplicates_found);
 
   *output_graph_def = current_graph_def;
@@ -295,7 +296,8 @@ Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
       },
       {true}, &replaced_graph_def));
 
-  RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def);
+  RenameNodeInputs(replaced_graph_def, inputs_to_rename,
+                   std::unordered_set<string>(), output_graph_def);
 
   return Status::OK();
 }
@@ -372,9 +374,9 @@ Status QuantizePlaceholders(const GraphDef& input_graph_def,
 
   GraphDef first_pass_graph_def;
   RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
-                   &first_pass_graph_def);
+                   std::unordered_set<string>(), &first_pass_graph_def);
   RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
-                   output_graph_def);
+                   std::unordered_set<string>(), output_graph_def);
 
   return Status::OK();
 }
diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc
index 3290e65512a..429dbdd0b19 100644
--- a/tensorflow/tools/graph_transforms/remove_nodes.cc
+++ b/tensorflow/tools/graph_transforms/remove_nodes.cc
@@ -80,7 +80,7 @@ Status RemoveNodes(const GraphDef& input_graph_def,
           {true}, &replaced_graph_def));
       // Make sure all references to removed nodes now point to their inputs.
       RenameNodeInputs(replaced_graph_def, inputs_to_rename,
-                       &current_graph_def);
+                       std::unordered_set<string>(), &current_graph_def);
     } while (any_nodes_removed);
   }
 
diff --git a/tensorflow/tools/graph_transforms/round_weights.cc b/tensorflow/tools/graph_transforms/round_weights.cc
index 6332876077c..72927e439b7 100644
--- a/tensorflow/tools/graph_transforms/round_weights.cc
+++ b/tensorflow/tools/graph_transforms/round_weights.cc
@@ -33,8 +33,9 @@ namespace graph_transforms {
 Status RoundWeights(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
-  int64 num_steps;
-  TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
+  int32 num_steps;
+  TF_RETURN_IF_ERROR(
+      context.GetOneInt32Parameter("num_steps", 256, &num_steps));
   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
       input_graph_def, {"Const"},
       [num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index ecdae72c11d..310c331e8f0 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -469,6 +469,7 @@ Status ReplaceMatchingOpTypes(
 
 Status RenameNodeInputs(const GraphDef& input_graph_def,
                         const std::map<string, string>& inputs_to_rename,
+                        const std::unordered_set<string>& nodes_to_ignore,
                         GraphDef* output_graph_def) {
   std::map<string, std::vector<std::pair<string, string>>>
       canonical_inputs_to_rename;
@@ -494,6 +495,9 @@ Status RenameNodeInputs(const GraphDef& input_graph_def,
                                        input_node_name);
       }
       already_visited.insert(input_node_name);
+      if (nodes_to_ignore.count(node.name())) {
+        break;
+      }
       bool any_match_found = false;
       for (const std::pair<string, string>& input_to_rename :
            canonical_inputs_to_rename.at(input_node_name)) {
@@ -630,9 +634,26 @@ Status TransformFuncContext::GetOneStringParameter(const string& name,
   }
 }
 
-Status TransformFuncContext::GetOneIntParameter(const string& name,
-                                                int64 default_value,
-                                                int64* result) const {
+Status TransformFuncContext::GetOneInt32Parameter(const string& name,
+                                                  int32 default_value,
+                                                  int32* result) const {
+  const int params_count = CountParameters(name);
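+  // A missing parameter isn't an error; it just means the caller's default
+  // gets used.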
+  if (params_count == 0) {
+    *result = default_value;
+    return Status::OK();
+  }
+  string string_value;
+  TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
+  if (!strings::safe_strto32(StringPiece(string_value), result)) {
+    return errors::InvalidArgument("Couldn't interpret the ", name,
+                                   " argument as a number: ", string_value);
+  }
+  return Status::OK();
+}
+
+Status TransformFuncContext::GetOneInt64Parameter(const string& name,
+                                                  int64 default_value,
+                                                  int64* result) const {
   const int params_count = CountParameters(name);
   if (params_count == 0) {
     *result = default_value;
diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h
index 2c98440907c..54808efa9fb 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.h
+++ b/tensorflow/tools/graph_transforms/transform_utils.h
@@ -17,8 +17,10 @@ limitations under the License.
 #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
 
 #include <set>
+#include <unordered_set>
 #include <vector>
 
+#include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -201,9 +203,11 @@ Status ReplaceMatchingOpTypes(
 // Returns a list of the unique nodes found in this match.
 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
 
-// Changes all input references to a particular node name.
+// Changes all input references to a particular node name. Any nodes with names
+// listed in nodes_to_ignore will not have their inputs rewritten.
 Status RenameNodeInputs(const GraphDef& input_graph_def,
                         const std::map<string, string>& inputs_to_rename,
+                        const std::unordered_set<string>& nodes_to_ignore,
                         GraphDef* output_graph_def);
 
 // Utility function that copies all the nodes found in a match into the
@@ -225,11 +229,17 @@ struct TransformFuncContext {
   Status GetOneStringParameter(const string& name, const string& default_value,
                                string* result) const;
 
-  // Gets a single occurrence of a parameter as an integer, falling back to a
-  // default if it isn't present and returning an error if it isn't convertible
-  // to a number.
-  Status GetOneIntParameter(const string& name, int64 default_value,
-                            int64* result) const;
+  // Gets a single occurrence of a parameter as a 32-bit integer, falling back
+  // to a default if it isn't present and returning an error if it isn't
+  // convertible to a number.
+  Status GetOneInt32Parameter(const string& name, int32 default_value,
+                              int32* result) const;
+
+  // Gets a single occurrence of a parameter as a 64-bit integer, falling back
+  // to a default if it isn't present and returning an error if it isn't
+  // convertible to a number.
+  Status GetOneInt64Parameter(const string& name, int64 default_value,
+                              int64* result) const;
 
   // Gets a single occurrence of a parameter as a floating point number, falling
   // back to a default if it isn't present and returning an error if it isn't
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index 9e6ddb46b70..92ebc358342 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -541,7 +541,9 @@ class TransformUtilsTest : public ::testing::Test {
     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
 
     GraphDef renamed_graph_def;
-    TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def));
+    TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}},
+                                  std::unordered_set<string>(),
+                                  &renamed_graph_def));
 
     std::map<string, const NodeDef*> node_map;
     MapNamesToNodes(renamed_graph_def, &node_map);
@@ -579,7 +581,7 @@ class TransformUtilsTest : public ::testing::Test {
     GraphDef renamed_graph_def;
     TF_ASSERT_OK(RenameNodeInputs(
         graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
-        &renamed_graph_def));
+        std::unordered_set<string>(), &renamed_graph_def));
 
     std::map<string, const NodeDef*> node_map;
     MapNamesToNodes(renamed_graph_def, &node_map);
@@ -615,8 +617,9 @@ class TransformUtilsTest : public ::testing::Test {
     TF_ASSERT_OK(root.ToGraphDef(&graph_def));
 
     GraphDef renamed_graph_def;
-    Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
-                                            &renamed_graph_def);
+    Status rename_status =
+        RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
+                         std::unordered_set<string>(), &renamed_graph_def);
     EXPECT_FALSE(rename_status.ok());
   }
 
@@ -650,6 +653,7 @@ class TransformUtilsTest : public ::testing::Test {
 
     GraphDef renamed_graph_def;
     TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
+                                  std::unordered_set<string>(),
                                   &renamed_graph_def));
 
     std::map<string, const NodeDef*> node_map;
@@ -658,6 +662,45 @@ class TransformUtilsTest : public ::testing::Test {
     EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
   }
 
+  void TestRenameNodeInputsWithIgnores() {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    const int width = 10;
+
+    Tensor a_data(DT_FLOAT, TensorShape({width}));
+    test::FillIota<float>(&a_data, 1.0f);
+    Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
+
+    Tensor b_data(DT_FLOAT, TensorShape({width}));
+    test::FillIota<float>(&b_data, 1.0f);
+    Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
+
+    Output add = Add(root.WithOpName("add"), a_const, a_const);
+
+    Output add2 = Add(root.WithOpName("add2"), a_const, a_const);
+
+    Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+
+    Output mul = Mul(root.WithOpName("mul"), add, placeholder);
+
+    Output mul2 = Mul(root.WithOpName("output"), mul, add2);
+
+    GraphDef graph_def;
+    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+    GraphDef renamed_graph_def;
+    TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, {"add2"},
+                                  &renamed_graph_def));
+
+    std::map<string, const NodeDef*> node_map;
+    MapNamesToNodes(renamed_graph_def, &node_map);
+    EXPECT_EQ("b", node_map.at("add")->input(0));
+    EXPECT_EQ("b", node_map.at("add")->input(1));
+    EXPECT_EQ("a", node_map.at("add2")->input(0));
+    EXPECT_EQ("a", node_map.at("add2")->input(1));
+  }
+
   void TestFindInvalidInputs() {
     GraphDef graph_def;
 
@@ -943,20 +986,36 @@ class TransformUtilsTest : public ::testing::Test {
     EXPECT_EQ("d", value);
   }
 
-  void TestGetOneIntParameter() {
+  void TestGetOneInt32Parameter() {
+    TransformFuncContext context;
+    context.params.insert({"foo", {"10", "20"}});
+    context.params.insert({"bar", {"-23"}});
+    context.params.insert({"not_a_number", {"not_numerical"}});
+    context.params.insert({"float", {"-23.232323"}});
+    int32 value;
+    TF_EXPECT_OK(context.GetOneInt32Parameter("bar", 0, &value));
+    EXPECT_EQ(-23, value);
+    EXPECT_FALSE(context.GetOneInt32Parameter("foo", 0, &value).ok());
+    TF_EXPECT_OK(context.GetOneInt32Parameter("not_present", 10, &value));
+    EXPECT_EQ(10, value);
+    EXPECT_FALSE(context.GetOneInt32Parameter("not_a_number", 0, &value).ok());
+    EXPECT_FALSE(context.GetOneInt32Parameter("float", 0, &value).ok());
+  }
+
+  void TestGetOneInt64Parameter() {
     TransformFuncContext context;
     context.params.insert({"foo", {"10", "20"}});
     context.params.insert({"bar", {"-23"}});
     context.params.insert({"not_a_number", {"not_numerical"}});
     context.params.insert({"float", {"-23.232323"}});
     int64 value;
-    TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
+    TF_EXPECT_OK(context.GetOneInt64Parameter("bar", 0, &value));
     EXPECT_EQ(-23, value);
-    EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
-    TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
+    EXPECT_FALSE(context.GetOneInt64Parameter("foo", 0, &value).ok());
+    TF_EXPECT_OK(context.GetOneInt64Parameter("not_present", 10, &value));
     EXPECT_EQ(10, value);
-    EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
-    EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
+    EXPECT_FALSE(context.GetOneInt64Parameter("not_a_number", 0, &value).ok());
+    EXPECT_FALSE(context.GetOneInt64Parameter("float", 0, &value).ok());
   }
 
   void TestGetOneFloatParameter() {
@@ -1111,6 +1170,10 @@ TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
   TestRenameNodeInputsWithWildcard();
 }
 
+TEST_F(TransformUtilsTest, TestRenameNodeInputsWithIgnores) {
+  TestRenameNodeInputsWithIgnores();
+}
+
 TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
 
 TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
@@ -1127,7 +1190,13 @@ TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
   TestGetOneStringParameter();
 }
 
-TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
+TEST_F(TransformUtilsTest, TestGetOneInt32Parameter) {
+  TestGetOneInt32Parameter();
+}
+
+TEST_F(TransformUtilsTest, TestGetOneInt64Parameter) {
+  TestGetOneInt64Parameter();
+}
 
 TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
   TestGetOneFloatParameter();