diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 3888a3e1a01..861b114e7e7 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -252,6 +252,16 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, bool DoConstantFolding(const ConstantFoldingOptions& opts, FunctionLibraryRuntime* function_library, Env* env, Device* partition_device, Graph* graph) { + bool was_mutated; + Status unused_status = DoConstantFoldingWithStatus( + opts, function_library, env, partition_device, graph, &was_mutated); + return was_mutated; +} + +Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts, + FunctionLibraryRuntime* function_library, + Env* env, Device* partition_device, + Graph* graph, bool* was_mutated) { DumpGraph("Before", graph); const FunctionLibraryDefinition* flib_def = nullptr; @@ -263,7 +273,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes); if (constant_foldable_nodes.empty()) { VLOG(1) << "No constant foldable nodes found"; - return false; + *was_mutated = false; + // This is not an error, so return the status as OK. + return Status::OK(); } std::map tensors_to_fetch; @@ -273,7 +285,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, if (tensors_to_fetch.empty()) { VLOG(1) << "No constant nodes found that feed into the original graph."; - return false; + *was_mutated = false; + // This is not an error, so return the status as OK. + return Status::OK(); } VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : " << graph->num_node_ids(); @@ -292,7 +306,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, {} /* inputs*/, tensors_to_fetch_names, &outputs); if (!s.ok()) { VLOG(1) << "Could not fetch constants: " << s; - return false; + *was_mutated = false; + // This is not an error, so return the status as OK. + return s; } // Fetch the constant tensors and replace the corresponding tensors in the @@ -307,7 +323,8 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, DumpGraph("After", graph); - return num_nodes_replaced > 0; + *was_mutated = (num_nodes_replaced > 0); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index 0895a8493f2..9e3479e50b0 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -29,7 +29,16 @@ namespace tensorflow { // and replaces those nodes with the result of the evaluation. // "partition_device", if non-null, is the device where all the graph nodes are // assumed to execute. -// Returns true if and only if "graph" has been mutated. +// Sets `was_mutated` to true if and only if "graph" has been mutated. +// The status is only set to a non-OK state if an unexpected error is hit +// running the graph. +Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts, + FunctionLibraryRuntime* function_library, + Env* env, Device* partition_device, + Graph* graph, bool* was_mutated); + +// Version of the function that doesn't return a Status, for backwards +// compatibility. bool DoConstantFolding(const ConstantFoldingOptions& opts, FunctionLibraryRuntime* function_library, Env* env, Device* partition_device, Graph* graph); diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index 2e1dc16c166..17001481722 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -228,8 +228,12 @@ TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) { g->AddControlEdge(concat_send, g->sink_node()); // The above concat should not have been constant folded. - EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, - Env::Default(), nullptr, g)); + bool was_mutated; + Status status = + DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr, + Env::Default(), nullptr, g, &was_mutated); + EXPECT_FALSE(was_mutated); + TF_EXPECT_OK(status); } TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) { @@ -257,8 +261,12 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) { g->AddControlEdge(times_two_send, g->sink_node()); // The above function call should not have been constant folded. - EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, - Env::Default(), nullptr, g)); + bool was_mutated; + status = + DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr, + Env::Default(), nullptr, g, &was_mutated); + EXPECT_FALSE(was_mutated); + EXPECT_FALSE(status.ok()); g_ = nullptr; } @@ -337,10 +345,16 @@ TEST_F(ConstantFoldingTest, TestImmutableConst) { auto result2 = ops::MatMul(root, result1, c); TF_ASSERT_OK(root.ToGraph(g)); TestTFEnvironment test_env; - EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, - Env::Default(), nullptr, g)); - EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, &test_env, - nullptr, g)); + bool was_mutated; + Status status = + DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr, + Env::Default(), nullptr, g, &was_mutated); + EXPECT_FALSE(was_mutated); + EXPECT_FALSE(status.ok()); + status = DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr, + &test_env, nullptr, g, &was_mutated); + EXPECT_TRUE(was_mutated); + TF_EXPECT_OK(status); } } // namespace diff --git a/tensorflow/core/kernels/quantize_op.cc b/tensorflow/core/kernels/quantize_op.cc index 458bede485d..7b34c32cebd 100644 --- a/tensorflow/core/kernels/quantize_op.cc +++ b/tensorflow/core/kernels/quantize_op.cc @@ -41,11 +41,12 @@ template class QuantizeV2Op : public OpKernel { public: explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { - half_range_ = !std::is_signed::value - ? 0.0f - : (std::numeric_limits::max() - - std::numeric_limits::min() + 1) / - 2.0f; + half_range_ = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + static_cast(std::numeric_limits::min()) + 1) / + 2.0f; string mode_string; OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); OP_REQUIRES(ctx, @@ -90,7 +91,8 @@ class QuantizeV2Op : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); if (mode_ == QUANTIZE_MODE_MIN_COMBINED) { const float scale_factor = - (std::numeric_limits::max() - std::numeric_limits::min()) / + (static_cast(std::numeric_limits::max()) - + static_cast(std::numeric_limits::min())) / (max_range - min_range); // Quantize: @@ -162,5 +164,8 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), QuantizeV2Op); +REGISTER_KERNEL_BUILDER( + Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint("T"), + QuantizeV2Op); } // namespace tensorflow diff --git a/tensorflow/core/kernels/quantize_op_test.cc b/tensorflow/core/kernels/quantize_op_test.cc index 670225c0dc0..41996852f16 100644 --- a/tensorflow/core/kernels/quantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_op_test.cc @@ -47,6 +47,46 @@ TEST_F(QuantizedOpTest, QuantizeV2) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(QuantizedOpTest, QuantizeV2_32Bit) { + TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DataTypeToEnum::v()) + .Attr("mode", "MIN_FIRST") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + const int element_count = 8; + AddInputFromArray( + TensorShape({element_count}), + {-500.0f, 0.0f, 1.0f, 1.25f, 1.75f, 127.0f, 255.0f, 500.0f}); + AddInputFromArray(TensorShape({1}), {-256.0f}); + AddInputFromArray(TensorShape({1}), {256.0f}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_QINT32, TensorShape({element_count})); + test::FillValues(&expected, + { + std::numeric_limits::min(), 0, + static_cast(1.0f * (1 << 23)), + static_cast(1.25f * (1 << 23)), + static_cast(1.75f * (1 << 23)), + static_cast(127.0f * (1 << 23)), + static_cast(255.0f * (1 << 23)), + std::numeric_limits::max(), + }); + // We expect there will be some fuzziness in the lower bits, since this is + // converting from float. + const int64 epsilon = 1 << 8; + const qint32* output_data = GetOutput(0)->flat().data(); + const qint32* expected_data = expected.flat().data(); + for (int i = 0; i < element_count; ++i) { + const int64 delta = output_data[i] - expected_data[i]; + EXPECT_GT(epsilon, std::abs(delta)) + << "output_data[" << i << "]=" << output_data[i] << ", expected_data[" + << i << "]=" << expected_data[i] << ", delta=" << delta; + } +} + TEST_F(QuantizedOpTest, QuantizeV2Ports) { TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") .Input(FakeInput(DT_FLOAT)) diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 02500cccdc7..0e8da270170 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -29,6 +29,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", ], @@ -52,9 +53,23 @@ tf_cc_test( ) cc_library( - name = "fold_constants_lib", + name = "transforms_lib", srcs = [ + "fold_batch_norms.cc", "fold_constants_lib.cc", + "fold_old_batch_norms.cc", + "fuse_convolutions.cc", + "obsfucate_names.cc", + "quantize_nodes.cc", + "quantize_weights.cc", + "remove_attribute.cc", + "remove_device.cc", + "remove_nodes.cc", + "rename_attribute.cc", + "rename_op.cc", + "round_weights.cc", + "sort_by_execution_order.cc", + "strip_unused_nodes.cc", ], hdrs = [ "fold_constants_lib.h", @@ -65,20 +80,98 @@ cc_library( ":transform_utils", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels:quantized_ops", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "transforms_test", + size = "small", + srcs = [ + "fold_batch_norms_test.cc", + "fold_constants_test.cc", + "fold_old_batch_norms_test.cc", + "fuse_convolutions_test.cc", + "obsfucate_names_test.cc", + "quantize_nodes_test.cc", + "quantize_weights_test.cc", + "remove_attribute_test.cc", + "remove_device_test.cc", + "remove_nodes_test.cc", + "rename_attribute_test.cc", + "rename_op_test.cc", + "round_weights_test.cc", + "sort_by_execution_order_test.cc", + "strip_unused_nodes_test.cc", + ], + deps = [ + ":transform_utils", + ":transforms_lib", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:quantized_ops", + ], +) + +cc_library( + name = "transform_graph_lib", + srcs = ["transform_graph.cc"], + hdrs = ["transform_graph.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":transform_utils", + ":transforms_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +# This library includes a main function, to make it easy to create other +# versions of the tool linked against different operator libs. +cc_library( + name = "transform_graph_main_lib", + srcs = ["transform_graph_main.cc"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":transform_graph_lib", + ":transforms_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_binary( + name = "transform_graph", + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":transform_graph_main_lib", ], ) tf_cc_test( - name = "fold_constants_test", - size = "small", - srcs = ["fold_constants_test.cc"], + name = "transform_graph_test", + size = "medium", + srcs = ["transform_graph_test.cc"], deps = [ - ":fold_constants_lib", + ":transform_graph_lib", ":transform_utils", "//tensorflow/cc:cc_ops", "//tensorflow/cc:sendrecv_ops", @@ -95,23 +188,24 @@ tf_cc_test( # This library includes a main function, to make it easy to create other # versions of the tool linked against different operator libs. cc_library( - name = "fold_constants_main_lib", - srcs = ["fold_constants_tool.cc"], + name = "summarize_graph_main_lib", + srcs = ["summarize_graph_main.cc"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - ":fold_constants_lib", + ":transform_utils", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], ) cc_binary( - name = "fold_constants_tool", + name = "summarize_graph", copts = tf_copts(), linkstatic = 1, visibility = ["//visibility:public"], deps = [ - ":fold_constants_main_lib", + ":summarize_graph_main_lib", ], ) diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md new file mode 100644 index 00000000000..7e5d52f3337 --- /dev/null +++ b/tensorflow/tools/graph_transforms/README.md @@ -0,0 +1,858 @@ +# Graph Transform Tool + +## Table of Contents + +* [Introduction](#introduction) +* [Using the Graph Transform Tool](#using-the-graph-transform-tool) +* [Inspecting Graphs](#inspecting-graphs) +* [Common Use Cases](#common-use-cases) + * [Optimizing for Deployment](#optimizing-for-deployment) + * [Fixing Missing Kernel Errors on + Mobile](#fixing-missing-kernel-errors-on-mobile) + * [Shrinking File Size](#shrinking-file-size) + * [Eight-bit Calculations](#eight-bit-calculations) +* [Transform Reference](#transform-reference) + * [fold_batch_norms](#fold_batch_norms) + * [fold_constants](#fold_constants) + * [fold_old_batch_norms](#fold_old_batch_norms) + * [fuse_convolutions](#fuse_convolutions) + * [merge_duplicate_nodes](#merge_duplicate_nodes) + * [obsfucate_names](#obsfucate_names) + * [quantize_nodes](#quantize_nodes) + * [quantize_weights](#quantize_weights) + * [remove_attribute](#remove_attribute) + * [remove_device](#remove_device) + * [remove_nodes](#remove_nodes) + * [rename_attribute](#rename_attribute) + * [rename_op](#rename_op) + * [round_weights](#round_weights) + * [sort_by_execution_order](#sort_by_execution_order) + * [strip_unused_nodes](#strip_unused_nodes) +* [Writing Your Own Transforms](#writing-your-own-transforms) + * [Transform Functions](#transform-functions) + * [Pattern Syntax](#pattern-syntax) + * [ReplaceMatchingOpTypes](#replacematchingoptypes) + * [Parameters](#parameters) + * [Function Libraries](#function-libraries) + * [Registering](#registering) + +## Introduction + +When you have finished training a model and want to deploy it in production, +you'll often want to modify it to better run in its final environment. For +example if you're targeting a phone you might want to shrink the file size by +quantizing the weights, or optimize away batch normalization or other +training-only features. The Graph Transform framework offers a suite of tools +for modifying computational graphs, and a framework to make it easy to write +your own modifications. + +This guide is structured into three main parts, first giving some tutorials on +how to perform common tasks, second a reference covering all of the different +transformations that are included, together with the options that apply to them, +and third a guide to creating your own transforms. + +## Using the Graph Transform Tool + +The Graph Transform tool is designed to work on models that are saved as +GraphDef files, usually in a binary protobuf format. This is the low-level +definition of a TensorFlow computational graph, including a list of nodes and +the input and output connections between them. If you're using a Python API to +train your model, this will usually be saved out in the same directory as your +checkpoints, and usually has a '.pb' suffix. + +If you want to work with the values of your trained parameters, for example to +quantize weights, you'll need to run +[tensorflow/python/tools/freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) +to convert the checkpoint values into embedded constants within the graph file +itself. + +You call the Graph Transform tool itself like this: + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +strip_unused_nodes(type=float, shape="1,299,299,3") \ +remove_nodes(op=Identity, op=CheckNumerics) \ +fold_old_batch_norms \ +' +``` + +The arguments here are specifying where to read the graph from, where to write +the transformed version to, what the input and output layers are, and what +transforms to modify the graph with. The transforms are given as a list of +names, and can each have arguments themselves. These transforms define the +pipeline of modifications that are applied in order to produce the output. +Sometimes you need some transforms to happen before others, and the ordering +within the list lets you specify which happen first. + +## Inspecting Graphs + +Many of the transforms that the tool supports need to know what the input and +output layers of the model are. The best source for these is the model training +process, where for a classifier the inputs will be the nodes that receive the +data from the training set, and the output will be the predictions. If you're +unsure, the +[summarize_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph.cc) +can inspect the model and provide guesses about likely input and output nodes, +as well as other information that's useful for debugging. Here's an example of +how to use it on the [Inception V3 +graph](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz): + +```bash +bazel build tensorflow/tools/graph_transforms:summarize_graph +bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=tensorflow_inception_graph.pb +``` + +## Common Use Cases + +This section has small guides for some of the most frequently-used +transformation pipelines, aimed at users who want to quickly accomplish one of +these tasks. A lot of them will use the Inception V3 model for their examples, +which can be downloaded from +[http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz). + +### Optimizing for Deployment + +If you've finished training your model and want to deploy it on a server or a +mobile device, you'll want it to run as fast as possible, and with as few +non-essential dependencies as you can. This recipe removes all of the nodes that +aren't called during inference, shrinks expressions that are always constant +into single nodes, and optimizes away some multiply operations used during batch +normalization by pre-multiplying the weights for convolutions. + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +strip_unused_nodes(type=float, shape="1,299,299,3") \ +remove_nodes(op=Identity, op=CheckNumerics) \ +fold_constants(ignore_errors=true) \ +fold_batch_norms \ +fold_old_batch_norms\ +' +``` + +The batch norm folding is included twice because there are two different flavors +of batch normalization used in TensorFlow. The older version was implemented +with a single BatchNormWithGlobalNormalization op, but it was deprecated in +favor of a more recent approach using individual ops to implement the same +computation. The two transforms are in there so that both styles are recognized +and optimized. + +### Fixing Missing Kernel Errors on Mobile + +The mobile version of TensorFlow is focused on inference, and so by default the +list of supported ops (defined in +[tensorflow/core/kernels/BUILD:android_extended_ops](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/BUILD#L2452) +for Bazel and +[tensorflow/contrib/makefile/tf_op_files.txt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/tf_op_files.txt) +for make builds) doesn't include a lot that are training related. This can cause +`No OpKernel was registered to support Op` errors when a GraphDef is loaded, +even if the op isn't going to be executed. + +If you see this error and it's an op that you do actually want to run on mobile, +then you'll need to make local modifications to the build files to include the +right .cc file that defines it. In a lot of cases the op is just a vestigial +remnant from the training process though, and if that's true then you can run +the [strip_unused_nodes](#strip_unused_nodes), specifying the inputs and outputs +of your inference usage, to remove those unneccessary nodes: + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +strip_unused_nodes(type=float, shape="1,299,299,3") \ +fold_constants \ +fold_batch_norms \ +fold_old_batch_norms\ +' +``` + +### Shrinking File Size + +If you're looking to deploy your model as part of a mobile app, then keeping the +download size as small as possible is important. For most TensorFlow models, the +largest contributors to the file size are the weights passed in to convolutional +and fully-connected layers, so anything that can reduce the storage size for +those is very useful. Luckily most neural networks are resistant to noise, so +it's possible to change the representation of those weights in a lossy way +without losing very much accuracy overall. + +On both iOS and Android app packages are compressed before download, so the +simplest way to reduce the bandwidth your users need to receive your app is to +provide raw data that compresses more easily. By default the weights are stored +as floating-point values, and even tiny differences between numbers result in +very different bit patterns, and so these don't compress very well. If you round +the weights so that nearby numbers are stored as exactly the same values, the +resulting bit stream has a lot more repetition and so compresses down a lot more +effectively. To try this technique on your model, run the +[round_weights](#round_weights] transform. + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +round_weights(num_steps=256) \ +' +``` + +You should see that the `optimized_inception_graph.pb` output file is the same +size as the input, but if you run zip on it to compress it, it's almost 70% +smaller than if you zip the original! The nice thing about this transform is +that it doesn't change the structure of the graph at all, so it's running +exactly the same operations and should have the same latency and memory usage as +before. You can adjust the `num_steps` parameter to control how many values each +weight buffer is rounded to, so lower numbers will increase the compression at +the cost of accuracy. + +As a further step, you can store the weights into eight-bit values directly. +Here's the recipe for that: + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +quantize_weights \ +' +``` + +You should see that the size of the output graph is about a quarter of the +original. The downside to this approach compared to round_weights is that extra +decompression ops are inserted to convert the eight-bit values back into +floating point, but optimizations in TensorFlow's runtime should ensure these +results are cached and so you shouldn't see the graph run any more slowly. + +So far we've been concentrating on weights because those generally take up the +most space. If you have a graph with a lot of small nodes in it, the names of +those nodes can start to take up a noticeable amount of space too. To shrink +those down, you can run the [obsfucate_names](#obsfucate_names) transform, which +replaces all the names (except for inputs and outputs) with short, cryptic but +unique ids: + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +obsfucate_names \ +' +``` + +### Eight-bit Calculations + +For some platforms it's very helpful to be able to do as many calculations as +possible in eight-bit, rather than floating-point. The support for this in +TensorFlow is still experimental and evolving, but you can convert models into +quantized form using the graph transform tool: + +```bash +bazel build tensorflow/tools/graph_transforms:transform_graph +bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ +--in_graph=tensorflow_inception_graph.pb \ +--out_graph=optimized_inception_graph.pb \ +--inputs='Mul:0' \ +--outputs='softmax:0' \ +--transforms='\ +strip_unused_nodes(type=float, shape="1,299,299,3") \ +remove_nodes(op=Identity, op=CheckNumerics) \ +fold_old_batch_norms \ +quantize_weights \ +quantize_nodes \ +strip_unused_nodes \ +' +``` + +This process converts all the operations in the graph that have quantized +equivalents, and leaves the rest in floating point. Only a subset of ops are +supported, and on many platforms the quantized code may actually be slower than +the float equivalents, but this is a way of increasing performance substantially +when all the circumstances are right. + +A full guide to optimizing for quantization is beyond the scope of this guide, +but one thing that can help is using the FakeQuantWithMinMaxVars op after Conv2D +or similar operations during training. This trains the min/max variables that +control the range used for quantization, so that the range doesn't have to be +calculated dynamically by RequantizationRange during inference. + +## Transform Reference + +The transforms string is parsed as a series of transform names, each of which +can have multiple named arguments inside parentheses. Arguments are separated by +commas, and double-quotes (") can be used to hold argument values if they +themselves contain commas (for example shape definitions). + +The --inputs and --outputs are shared across all transforms, since it's common +to need to know what the ingoing and outgoing nodes in the graph are. You should +make sure you set these correctly before calling the graph transform tool, and +if you're in doubt check with the model's author, or use the `check_graph` tool +to examine likely inputs and outputs. + +All transforms can be passed the `ignore_errors` flag, with the value set to +either true or false. By default any errors that happen within a transform will +abort the whole process, but if you enable this then an error will just be +logged and the transform skipped. This is especially useful for optional +transforms where version errors or other unimportant problems may trigger an +error. + +### fold_batch_norms + +Args: None + +This transform tries to optimize away the Mul that's introduced after a Conv2D +when batch normalization has been used during training. It scans the graph for +any channel-wise multiplies immediately after convolutions, and multiplies the +convolution's weights with the mul instead so it can be omitted. You'll need to +make sure you run [fold_constants](#fold_constants) first, since the pattern can +only be spotted if the normal complex expression that's produced by training for +the Mul input is collapsed down into a simple constant. + +### fold_constants + +Args: None + +Looks for any sub-graphs within the model that always evaluate to constant +expressions, and replaces them with those constants. This optimization is always +executed at run-time after the graph is loaded, so running it offline first +won't help latency, but it can simplify the graph and so make further processing +easier. It's often useful to call this with `fold_constants(ignore_errors=true)` +to continue on past transient errors, since this is just an optimization phase. + +### fold_old_batch_norms + +Args: None + +In the early days of TensorFlow, batch normalization was implemented using a +single monolithic `BatchNormWithGlobalNormalization` op. In modern versions, +adding batch normalization from Python will give you a series of smaller math +ops instead, to achieve the same effect without special-purpose code. If you +have a graph that uses the older-style, this transform will recognize and +optimize those ops for inference, in the same way that the +[fold_batch_norms](#fold_batch_norms) transform does for the new approach. + +### fuse_convolutions + +Args: None + +For graphs that use ResizeBilinear or MirrorPad ops before convolutions, +typically to scale up in the later stages of an image style transfer model for +example, it can save on memory usage and latency to combine the spatial +transformations with the convolution's im2col patch generation. This transform +looks out for that particular pattern of ops and replaces them with a fused +version that combines the resizing and padding with the convolution. + +### merge_duplicate_nodes + +Args: None + +If there are Const nodes with the same types and contents, or nodes with the +same inputs and attributes, this transform will merge them together. It can be +useful when you want to cut down the number of nodes in a graph that has a lot +of redundancy, and is always run as part of [quantize_nodes](#quantize_nodes) +since the processing there can introduce duplicates of constants that are used +in the quantize/dequantize process. + +### obsfucate_names + +Args: None + +Replaces all node's names with short generated ids, other than the inputs and +outputs. This also updates all references within the graph so that the structure +is preserved. This can be useful if you want to shrink the file size, or if you +want to make it harder to understand the architecture of your model before +releasing it. + +### quantize_nodes + +Args: + +* input_min: The lowest float value for any quantized placeholder inputs. +* input_max: The highest float value for any quantized placeholder inputs. If + both input_min and input_max are set, then any float placeholders in the + graph will be replaced with quantized versions, and consts will be created + to pass the range to subsequent operations. +* fallback_min: The lowest float value to use for requantizing activation + layers. +* fallback_max: The highest float value to use for requantizing activation + layers. If both fallback_min and fallback_max are set, then instead of using + RequantizationRange ops ro figure out the useful range dynamically when + converting the 32-bit output of ops like QuantizedConv2D and + QuantizedBiasAdd, hardwired consts with these values will be used instead. + This can help performance, if you know the range of your activation layers + ahead of time. + +Replaces any calculation nodes with their eight-bit equivalents, and adds in +conversion layers to allow remaining float operations to interoperate. This is +one of the most complex transforms, and involves multiple passes and a lot of +rewriting. It's also still an active area of research, so results may vary +depending on the platform and operations you're using in your model. You should +run [quantize_weights](#quantize_weights) first to ensure your Const ops are in +eight-bit form. + +### quantize_weights + +Args: None + +Converts any large (more than 15 element) float Const op into an eight-bit +equivalent, followed by a float conversion op so that the result is usable by +subsequent nodes. This is mostly useful for [shrinking file +sizes](#shrinking-file-size), but also helps with the more advanced +[quantize_nodes](#quantize_nodes) transform. + +### remove_attribute + +Args: + +* attribute_name: Name of the attribute you want to remove. +* op_name: Optional name of a single op to restrict the removal to. + +Deletes the given attribute from either all nodes, or just the one specified in +`op_name`. This can be a dangerous transform since it's easy to leave your graph +in an invalid state if you remove a required attribute. It can be useful in +special circumstances though. + +### remove_device + +Args: None + +All ops can have a hardware device specified. This can be a problem when you're +loading a graph on a different system than the model was trained on, since some +specified devices may not be available. In order to work with graphs like these, +you can run this transform to wipe the slate clean and delete the device +specifier from all ops. + +### remove_nodes + +Args: + +* op: The name of the op you want to remove. Can be repeated to remove + multiple ops. + +This is a potentially dangerous transform that looks for single-input, +single-output ops with the given names, removes them from the graph, and rewires +all inputs that use to pull from them to pull from the preceding node instead. +This is most useful for getting rid of ops like `CheckNumerics` that are useful +during training but just complicate the graph and increase latency during +inference. It's dangerous because it's possible that removing some ops may +change the output of your graph, so make sure you check the overall accuracy +after using this. + +### rename_attribute + +Args: + +* old_attribute_name: Current name of the attribute you want to rename. +* new_attribute_name: Name that you want the attribute to have now. +* op_name: If this is set, only change attributes for a given op type, + otherwise apply to all nodes with attribute names that match. + +Changes the name of the given attribute. This is often useful for upgrading +graph files as op definitions change over versions, since the renaming is often +enough to deal with minor changes. + +### rename_op + +Args: + +* old_op_name: Current name of the operation. +* new_op_name: Name to change to. + +Finds all ops with the given name, and changes them to the new one. This can be +useful for version upgrading if the changes between ops are minor apart from the +name. + +### round_weights + +Args: + +* num_steps: How many unique values to use in each buffer. + +Rounds all float values in large Const ops (more than 15 elements) to the given +number of steps. The unique values are chosen per buffer by linearly allocating +between the largest and smallest values present. This is useful when you'll be +deploying on mobile, and you want a model that will compress effectively. See +[shrinking file size](#shrinking-file-size) for more details. + +### sort_by_execution_order + +Args: None + +Arranges the nodes in the GraphDef in topological order, so that the inputs of +any given node are always earlier than the node itself. This is especially +useful when you're targeting a minimal inference engine, since you can just +execute the nodes in the given order knowing that the inputs will be computed +before they're needed. + +### strip_unused_nodes + +Args: + +* type: Default type for any new Placeholder nodes generated, for example + int32, float, quint8. +* shape: Default shape for any new Placeholder nodes generated, as + comma-separated dimensions. For example shape="1,299,299,3". The double + quotes are important, since otherwise the commas will be taken as argument + separators. +* name: Identifier for the placeholder arguments. +* type_for_name: What type to use for the previously-given name. +* shape_for_name: What shape to use for the previously-given name. + +Removes all nodes not used in calculated the layers given in `--outputs`, fed by +`--inputs`. This is often useful for removing training-only nodes like +save-and-restore or summary ops. It's also handy for solving the [missing kernel +errors problem](#fixing-missing-kernel-errors-on-mobile) when there are decode +or other ops you don't need in the inference path. + +The biggest complication is that it sometimes has to create new Placeholder ops, +so there are options to control their characteristics. This will happen if you +bypass a DecodeJpeg op by specifying an input layer deeper in the network, for +example, so you can pass in a raw image array instead of an encoded string as an +input. The decode op will be removed, together with the Placeholder that fed it, +but a new Placeholder is needed for the input layer you specify. The type and +shape arguments let you control the attributes of any new Placeholders that are +created. Plain `type` and `shape` set global defaults, but if you have different +inputs with varying characteristics, you'll need to pass in a list of arguments +where the preceding name specifies what layer each applies to. For example, if +you had two inputs in1 and in2, you could call `strip_unused_node(name=in1, +type_for_name=int32, shape_for_name="2,3", name=in2, type_for_name=float, +shape_for_name="1,10,10,3")`. + +## Writing Your Own Transforms + +The Graph Transform Tool is designed to make it as easy as possible to create +your own optimization, modification, and pre-processing transforms. At their +heart, all of the transforms take in a valid GraphDef, make some changes, and +output a new GraphDef. Each GraphDef is just a list of NodeDefs, each defining +one node in the graph and its connections. You can find more information on the +format at [this guide to TensorFlow model +files](https://www.tensorflow.org/versions/master/how_tos/tool_developers/index.html), +but for a simple example take a look at +[tensorflow/tools/graph_transforms/rename_op.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/rename_op.cc), +which implements the [rename_op](#rename_op) transform: + +```C++ +Status RenameOp(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + if (!context.params.count("old_op_name") || + (context.params.at("old_op_name").size() != 1) || + !context.params.count("new_op_name") || + (context.params.at("new_op_name").size() != 1)) { + return errors::InvalidArgument( + "remove_nodes expects exactly one 'old_op_name' and 'new_op_name' " + "argument, e.g. rename_op(old_op_name=Mul, new_op_name=Multiply)"); + } + + const string old_op_name = context.params.at("old_op_name")[0]; + const string new_op_name = context.params.at("new_op_name")[0]; + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + if (node.op() == old_op_name) { + new_node->set_op(new_op_name); + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("rename_op", RenameOp); +``` + +The heart of this transform is the loop through the input_graph_def's nodes. We +go through each op, add a new one to the output, copy the original's contents, +and then change the op over if it matches the parameters. There's a standard set +of parameters for every transform, so they all take in a GraphDef and context, +and write out into a new GraphDef. The registration macro at the bottom lets the +tool know what function to call when it finds the `rename_op` string in a +transforms list. + +### Transform Functions + +The standard signature that all transform functions have is defined as +`TransformFunc`, which takes in an input GraphDef, a `TransformFuncContext` +containing environment information, writes to an output GraphDef, and returns a +Status indicating whether the transform succeeded. + +The `TransformFuncContext` has a list of the inputs and outputs for the graph, +and the [parameter arguments](#parameters) that were passed into the transform +by the user. + +If you write a function that matches this signature, and [register +it](#registration), the graph transform tool will take care of calling it. + +### Pattern Syntax + +The `rename_op` example only needs to look at a single node at a time, but one +of the most common needs is to modify small sub-graphs within a model. To make +this easy, the Graph Transform Tool provides the `OpTypePattern` syntax. This is +a simple and compact way to specify patterns of nodes that you want to look for. +For example, if you want all Conv2D nodes that have a constant as their second +input, you would set up a pattern like this, using C++ initializer lists to +populate the structure: + +```C++ +OpTypePattern conv_pattern({"Conv2D", {{"*"}, {"Const"}}}); +``` + +It can be easier to visualize these initializers using indentation to show the +tree structure more clearly: + +```C++ +OpTypePattern conv_pattern({ + "Conv2D", + { + {"*"}, + {"Const"} + } +}); +``` + +In plain English this is saying, a Conv2D op with two inputs, the first of which +is any op type, and the second is a Const op. + +The op field can either contain a single "*", which means match any op type, one +op type (for example "Const"), or a set of op types separated by `|` symbols +(for example "Conv2D|MatMul|BiasAdd"). General regex patterns are not supported, +just these special cases. + +You can think of these patterns as very limited regular expressions designed to +pick out sub-trees in graphs. They are deliberately very constrained to the kind +of things we commonly find ourselves needing to do, to make creating and +debugging as straightforward as possible. + +Here's a much more complex example, from the [quantize_nodes](#quantize_nodes) +transform: + +```C++ +{"QuantizeV2", + { + {"Dequantize"}, + {"Min", + { + {"Reshape", + { + {"Dequantize"}, + {"Const"}, + } + }, + {"Const"}, + } + }, + {"Max", + { + {"Reshape", + { + {"Dequantize"}, + {"Const"}, + } + }, + {"Const"}, + } + }, + } +} +``` + +This is looking for QuantizeV2 nodes, with three inputs, the first of which is a +Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the +third is a Max which does the same. We know the end result of this sub-graph is +a no-op, since it's just turning an eight-bit buffer into float, and then +immediately converting it back to eight-bits, so if we look for this pattern and +remove it we can optimize the graph without changing the result. + +### ReplaceMatchingOpTypes + +It's very common to want to find all occurrences of a particular sub-graph in a +model, and replace them all with a different sub-graph that keeps the same local +input and output connections. For example with +[fuse_convolutions](#fuse_convolutions), we needed to find all Conv2D ops that +read their inputs from BilinearResizes, and replace those combinations with a +single FusedResizeAndPadConv2D op, but without affecting other ops. + +To make that sort of transformation easy, we created the +`ReplaceMatchingOpTypes` helper. This takes in a graph, an `OpTypePattern` +defining the sub-graph to look for, and a callback function to run for every +occurrence it finds. The job of this callback function is to look at the +`NodeMatch` that contains information about the current sub-graph, and return a +new sub-graph in the new_nodes list that will be used to replace the old +sub-graph. + +You can see how it's used in practice in the +[fuse_convolutions](#fuse_convolutions) code: + +```C++ +TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"Conv2D", + { + {"ResizeBilinear"}, + {"*"} + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + // Find all the nodes we expect in the subgraph. + const NodeDef& conv_node = match.node; + const NodeDef& resize_node = match.inputs[0].node; + const NodeDef& weights_node = match.inputs[1].node; + + // We'll be reusing the old weights. + new_nodes->push_back(weights_node); + + // Create a 'no-op' mirror padding node that has no effect. + NodeDef pad_dims_node; + pad_dims_node.set_op("Const"); + pad_dims_node.set_name(conv_node.name() + "_dummy_paddings"); + SetNodeAttr("dtype", DT_INT32, &pad_dims_node); + SetNodeTensorAttr("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0}, + &pad_dims_node); + new_nodes->push_back(pad_dims_node); + + // Set up the new fused version of the convolution op. + NodeDef fused_conv; + fused_conv.set_op("FusedResizeAndPadConv2D"); + fused_conv.set_name(match.node.name()); + AddNodeInput(resize_node.input(0), &fused_conv); + AddNodeInput(resize_node.input(1), &fused_conv); + AddNodeInput(pad_dims_node.name(), &fused_conv); + AddNodeInput(conv_node.input(1), &fused_conv); + CopyNodeAttr(resize_node, "align_corners", "resize_align_corners", + &fused_conv); + SetNodeAttr("mode", "REFLECT", &fused_conv); + CopyNodeAttr(conv_node, "T", "T", &fused_conv); + CopyNodeAttr(conv_node, "padding", "padding", &fused_conv); + CopyNodeAttr(conv_node, "strides", "strides", &fused_conv); + new_nodes->push_back(fused_conv); + + return Status::OK(); + }, + {}, &replaced_graph_def)); +``` + +Here you can see we define the pattern to look for, and in the callback function +use information from each of the nodes in the old sub-graph to create a new +fused node. We also copy over the old weights input node so that isn't lost. + +There are a few things to know about the `ReplaceMatchingOpTypes` function: + +* All of the nodes in any matching sub-graphs are removed from the new graph + created by the function. If any of them are needed, it's the callback + function's responsibility to add them back in. There's a `CopyOriginalMatch` + convenience call that will copy over all of the original nodes if you decide + you don't actually want to modify a particular sub-graph. + +* Nodes will never appear in more than one matched sub-graph. This is to + ensure that sub-trees are only replaced once, but it may mean that some + sub-graphs aren't spotted if they overlap with earlier matches. + +* The calling framework tries to ensure that the graph remains sane, by + looking at the new_nodes that are returned and making sure that no nodes + which are needed as inputs by nodes outside the sub-graph are removed. These + important nodes are listed in the `output_nodes` argument that's passed into + each replacement function call. You can disable this checking by setting + `allow_inconsistencies` to true in the options, but otherwise any + replacements that break the graph constraints will be cancelled. If you do + allow inconsistencies, it's your transform's responsibility to fix them up + before you return your final result. Functions like `RenameNodeInputs` can + be useful if you are doing wholesale node renaming for example. + +### Parameters + +The arguments that are in parentheses after the transform name when the tool is +called are parsed and placed into the params member of the TransformFuncContext +that's given to each transform. For every named argument, there's a vector of +strings containing all the values that it was given, in the order they were +given. These are treated a bit like command-line parameters, and it's the +transform's responsibility to parse them into the data types it needs, and raise +errors by returning a bad Status if any of them are ill-formed. + +As an example, here's a hypothetical transform call: + +``` +some_transform(foo=a, foo=b, bar=2, bob="1,2,3") +``` + +Here's what the std::map of strings looks like in the params member: + +``` +{{"foo", {"a", "b"}}, {"bar", {"2"}}, {"bob", {"1,2,3"}}} +``` + +The double quotes around the comma-separated argument to `bob` are important +because otherwise they'll be treated as separate arguments, and the parsing will +fail. + +Here's an example of how [round_weights](#round_weights) reads its `num_steps` +parameter: + +```C++ +string num_steps_string; +TF_RETURN_IF_ERROR( + GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string)); +int32 num_steps; +if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) { + return errors::InvalidArgument( + "Couldn't interpret the num_steps argument to round_weights as a " + "number:", + num_steps_string); +} +``` + +Things to notice here are that you have to convert the string to an integer, and +if the conversion fails you need to raise a meaningful error through the status +result of the transform. We're also using a helper function which raises an +error if the parameter is present multiple times, and uses a default if the user +hasn't specified it. + +### Function Libraries + +A newer feature of TensorFlow is the ability to create libraries of functions as +part of graphs. These are a bit like templates, which define macro operations in +terms of smaller components, which can then be instantiated with different input +and output connections inside the graph just like regular ops. Right now the +graph transform tool just copies these libraries between the input and output +graphs, but it's likely that more complex operations will be supported on them +in the future. + +### Registering + +The Graph Transform Tool associates names of transforms with the code to +implement them using the `REGISTER_GRAPH_TRANSFORM()` macro. This takes a string +and a function, and automagically registers the transform with the tool. You +will need to watch out for a few things though: + +* Because it's using global C++ objects in each file under the hood, the + linker can sometimes strip them out and lose the registration. In Bazel you + need to make sure you're linking any new transforms in as libraries, and use + the `alwayslink` flag in your `cc_binary` call. + +* You should be able to create your own copy of the transform_graph tool by + linking against the transform_graph_main_lib library in + tensorflow/tools/graph_transforms/BUILD. This contains all the `main()` + logic to parse command line arguments and call transforms. diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc new file mode 100644 index 00000000000..9f3393f1265 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc @@ -0,0 +1,108 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the +// Mul baked into the convolution weights, to save computation during inference. +Status FoldBatchNorms(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"Mul", // mul_node + { + {"Conv2D", // conv_node + { + {"*"}, // input_node + {"Const"}, // weights_node + } + }, + {"Const"}, // mul_values_node + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + // Find all the nodes we expect in the subgraph. + const NodeDef& mul_node = match.node; + const NodeDef& conv_node = match.inputs[0].node; + const NodeDef& input_node = match.inputs[0].inputs[0].node; + const NodeDef& weights_node = match.inputs[0].inputs[1].node; + const NodeDef& mul_values_node = match.inputs[1].node; + + Tensor weights = GetNodeTensorAttr(weights_node, "value"); + Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value"); + + // Make sure all the inputs really are vectors, with as many entries as + // there are columns in the weights. + const int64 weights_cols = weights.shape().dim_size(3); + if ((mul_values.shape().dims() != 1) || + (mul_values.shape().dim_size(0) != weights_cols)) { + return errors::InvalidArgument( + "Mul constant input to batch norm has bad shape: ", + mul_values.shape().DebugString()); + } + + // Multiply the original weights by the scale vector. + auto weights_matrix = weights.flat_inner_dims(); + Tensor scaled_weights(DT_FLOAT, weights.shape()); + auto scaled_weights_matrix = scaled_weights.flat_inner_dims(); + for (int64 row = 0; row < weights_matrix.dimension(0); ++row) { + for (int64 col = 0; col < weights_cols; ++col) { + scaled_weights_matrix(row, col) = + weights_matrix(row, col) * mul_values.flat()(col); + } + } + + // Construct the new nodes. + NodeDef scaled_weights_node; + scaled_weights_node.set_op("Const"); + scaled_weights_node.set_name(weights_node.name()); + SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node); + SetNodeTensorAttr("value", scaled_weights, &scaled_weights_node); + new_nodes->push_back(scaled_weights_node); + + new_nodes->push_back(input_node); + + NodeDef new_conv_node; + new_conv_node.CopyFrom(conv_node); + new_conv_node.set_name(mul_node.name()); + new_nodes->push_back(new_conv_node); + + return Status::OK(); + }, + {}, &replaced_graph_def)); + *output_graph_def = replaced_graph_def; + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("fold_batch_norms", FoldBatchNorms); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc new file mode 100644 index 00000000000..b9983fdd0b2 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status FoldBatchNorms(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class FoldBatchNormsTest : public ::testing::Test { + protected: + void TestFoldBatchNorms() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor mul_values_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&mul_values_data, {2.0f, 3.0f}); + Output mul_values_op = Const(root.WithOpName("mul_values"), + Input::Initializer(mul_values_data)); + + Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK( + FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("Mul", node.op()); + } + } +}; + +TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 8cd44934601..27baa3711c4 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -95,17 +95,16 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def, } Status RemoveUnusedNodes(const GraphDef& input_graph_def, - const std::vector& inputs, - const std::vector& outputs, + const TransformFuncContext& context, GraphDef* output_graph_def) { std::map node_map; MapNamesToNodes(input_graph_def, &node_map); std::map used_nodes; - for (const string& input : inputs) { + for (const string& input : context.input_names) { used_nodes[input] = true; } - std::vector current_nodes = outputs; + std::vector current_nodes = context.output_names; while (!current_nodes.empty()) { std::vector next_nodes; for (const string& node_name : current_nodes) { @@ -134,9 +133,10 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def, return Status::OK(); } +// Converts any sub-graphs that can be resolved into constant expressions into +// single Const ops. Status FoldConstants(const GraphDef& input_graph_def, - const std::vector& inputs, - const std::vector& outputs, + const TransformFuncContext& context, GraphDef* output_graph_def) { // Some older GraphDefs have saved _output_shapes attributes which are out of // date and cause import errors, so clean them up first. @@ -148,20 +148,24 @@ Status FoldConstants(const GraphDef& input_graph_def, ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr)); DeviceAttributes device_attributes; TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( - &input_graph, inputs, outputs, {}, device_attributes)); - if (!DoConstantFolding(ConstantFoldingOptions(), nullptr, Env::Default(), - nullptr, &input_graph)) { - return errors::InvalidArgument("Constant folding failed"); - } + &input_graph, context.input_names, context.output_names, {}, + device_attributes)); + bool was_mutated; + TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus( + ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph, + &was_mutated)); GraphDef folded_graph_def; input_graph.ToGraphDef(&folded_graph_def); GraphDef send_recvs_replaced; - TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def, inputs, - outputs, &send_recvs_replaced)); - TF_RETURN_IF_ERROR(RemoveUnusedNodes(send_recvs_replaced, inputs, outputs, - output_graph_def)); + TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def, + context.input_names, context.output_names, + &send_recvs_replaced)); + TF_RETURN_IF_ERROR( + RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def)); return Status::OK(); } +REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants); + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.h b/tensorflow/tools/graph_transforms/fold_constants_lib.h index a1773da752c..8aefa6ae0f1 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.h +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { namespace graph_transforms { @@ -27,15 +28,13 @@ namespace graph_transforms { // the names of all the nodes that data is fed into, or read out of, when the // graph is actually run. Status FoldConstants(const GraphDef& input_graph_def, - const std::vector& inputs, - const std::vector& outputs, + const TransformFuncContext& context, GraphDef* output_graph_def); // Analyzes which nodes are used for the given set of inputs and outputs, and // returns a copy of the graph with any that aren't used removed. Status RemoveUnusedNodes(const GraphDef& input_graph_def, - const std::vector& inputs, - const std::vector& outputs, + const TransformFuncContext& context, GraphDef* output_graph_def); } // namespace graph_transforms diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index 9a0e452e795..dac13f5c321 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -82,12 +82,13 @@ class ConstantFoldingTest : public ::testing::Test { TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors)); GraphDef folded_graph_def; - std::vector input_names; + graph_transforms::TransformFuncContext context; for (const std::pair& input : inputs) { - input_names.push_back(input.first); + context.input_names.push_back(input.first); } - TF_ASSERT_OK(graph_transforms::FoldConstants(graph_def, input_names, - outputs, &folded_graph_def)); + context.output_names = outputs; + TF_ASSERT_OK( + graph_transforms::FoldConstants(graph_def, context, &folded_graph_def)); std::unique_ptr folded_session( tensorflow::NewSession(tensorflow::SessionOptions())); @@ -187,7 +188,7 @@ class ConstantFoldingTest : public ::testing::Test { TF_ASSERT_OK(root.ToGraphDef(&graph_def)); GraphDef result_graph_def; TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( - graph_def, {"placeholder"}, {"output"}, &result_graph_def)); + graph_def, {{"placeholder"}, {"output"}}, &result_graph_def)); std::map node_map; graph_transforms::MapNamesToNodes(result_graph_def, &node_map); diff --git a/tensorflow/tools/graph_transforms/fold_constants_tool.cc b/tensorflow/tools/graph_transforms/fold_constants_tool.cc deleted file mode 100644 index bfcbdf6b144..00000000000 --- a/tensorflow/tools/graph_transforms/fold_constants_tool.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Utility that transforms a model with subgraphs that evaluate to constant -// functions into the equivalent model with those subgraphs replaced by Const -// nodes. This simplifies the graph, and makes some further transformations -// easier to perform. It's often useful to run the freeze_graph tool on the -// input graph beforehand to ensure variables have been transformed to Consts. -// -// bazel-bin/tensorflow/tools/graph_transforms/fold_constants_tool \ -// --in_graph=graph_def.pb \ -// --out_graph=folded_graph_def.pb \ -// --inputs=input1,input2 \ -// --outputs=output1,output2 -// -// Parameters: -// in_graph - name of a file with a frozen GraphDef proto in binary format. -// out_graph - name of the output file to save the folded version to. -// inputs - layer names of the nodes that will be fed data. -// outputs - layer names of the nodes that will be read from after running. - -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" - -namespace tensorflow { -namespace { - -int ParseFlagsAndConvertGraph(int argc, char* argv[]) { - string in_graph = ""; - string out_graph = ""; - string inputs_string = ""; - string outputs_string = ""; - std::vector flag_list = { - Flag("in_graph", &in_graph, "input graph file name"), - Flag("out_graph", &out_graph, "output graph file name"), - Flag("inputs", &inputs_string, "inputs"), - Flag("outputs", &outputs_string, "outputs"), - }; - string usage = Flags::Usage(argv[0], flag_list); - const bool parse_result = Flags::Parse(&argc, argv, flag_list); - // We need to call this to set up global state for TensorFlow. - port::InitMain(argv[0], &argc, &argv); - if (!parse_result) { - LOG(ERROR) << usage; - return -1; - } - if (argc > 1) { - LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; - return -1; - } - if (in_graph.empty()) { - LOG(ERROR) << "in_graph graph can't be empty"; - return -1; - } - if (out_graph.empty()) { - LOG(ERROR) << "out_graph graph can't be empty"; - return -1; - } - std::vector inputs = str_util::Split(inputs_string, ','); - std::vector outputs = str_util::Split(outputs_string, ','); - - GraphDef graph_def; - Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def); - if (!load_status.ok()) { - LOG(ERROR) << "Loading graph '" << in_graph << "' failed with " - << load_status.error_message(); - return -1; - } - - GraphDef folded_graph_def; - Status folding_result = graph_transforms::FoldConstants( - graph_def, inputs, outputs, &folded_graph_def); - if (!folding_result.ok()) { - LOG(ERROR) << "Folding failed " << folding_result.error_message(); - return -1; - } - - Status save_status = - WriteBinaryProto(Env::Default(), out_graph, folded_graph_def); - if (!save_status.ok()) { - LOG(ERROR) << "Saving graph '" << out_graph << "' failed with " - << save_status.error_message(); - return -1; - } - - return 0; -} - -} // namespace -} // namespace tensorflow - -int main(int argc, char* argv[]) { - return tensorflow::ParseFlagsAndConvertGraph(argc, argv); -} diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc new file mode 100644 index 00000000000..066727614c8 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -0,0 +1,193 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { +namespace { +// Ensures the tensor is the expected shape. +Status ErrorIfNotVector(const Tensor& input, const string& input_name, + int expected_width) { + if ((input.shape().dims() != 1) || + (input.shape().dim_size(0) != expected_width)) { + return errors::InvalidArgument(input_name, + " input to batch norm has bad shape: ", + input.shape().DebugString()); + } + return Status::OK(); +} +} // namespace + +// Finds monolithic batch norm ops (as used in early versions of TensorFlow) and +// converts them into premultiplied weight inputs to convolutions. +Status FoldOldBatchNorms(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef current_graph_def = input_graph_def; + // We have to do several passes to catch all the old BN nodes, since many of + // them may share inputs and so be excluded from replacement in one pass. + bool did_graph_change; + do { + did_graph_change = false; + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, // clang-format off + {"BatchNormWithGlobalNormalization", // batch_norm_node + { + {"Conv2D", // conv_node + { + {"*"}, // input_node + {"Const"}, // weights_node + } + }, + {"Const"}, // mean_node + {"Const"}, // variance_node + {"Const"}, // beta_node + {"Const"}, // gamma_node + } + }, // clang-format on + [&did_graph_change](const NodeMatch& match, + const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + // Find all the nodes we expect in the subgraph. + const NodeDef& batch_norm_node = match.node; + CHECK_EQ("BatchNormWithGlobalNormalization", batch_norm_node.op()); + const NodeDef& conv_node = match.inputs[0].node; + CHECK_EQ("Conv2D", conv_node.op()); + const NodeDef& input_node = match.inputs[0].inputs[0].node; + const NodeDef& weights_node = match.inputs[0].inputs[1].node; + CHECK_EQ("Const", weights_node.op()); + const NodeDef& mean_node = match.inputs[1].node; + CHECK_EQ("Const", mean_node.op()); + const NodeDef& variance_node = match.inputs[2].node; + CHECK_EQ("Const", variance_node.op()); + const NodeDef& beta_node = match.inputs[3].node; + CHECK_EQ("Const", beta_node.op()); + const NodeDef& gamma_node = match.inputs[4].node; + CHECK_EQ("Const", gamma_node.op()); + + // We have a set of vectors that we want to combine into a vector of + // scale values to apply column-wise to the weight input to the conv, + // and an offset vector that we'll apply to the output of the conv. + Tensor weights = GetNodeTensorAttr(weights_node, "value"); + Tensor mean = GetNodeTensorAttr(mean_node, "value"); + Tensor variance = GetNodeTensorAttr(variance_node, "value"); + Tensor beta = GetNodeTensorAttr(beta_node, "value"); + Tensor gamma = GetNodeTensorAttr(gamma_node, "value"); + const float variance_epsilon = + batch_norm_node.attr().at("variance_epsilon").f(); + const bool scale_after_normalization = + batch_norm_node.attr().at("scale_after_normalization").b(); + + // Make sure all the inputs really are vectors, with as many entries + // as there are columns in the weights. + const int64 weights_cols = weights.shape().dim_size(3); + TF_RETURN_IF_ERROR(ErrorIfNotVector(mean, "Mean", weights_cols)); + TF_RETURN_IF_ERROR( + ErrorIfNotVector(variance, "Variance", weights_cols)); + TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", weights_cols)); + TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", weights_cols)); + + // Calculate the scale and offset values to apply. + std::vector scale_values(weights_cols); + std::vector offset_values(weights_cols); + if (scale_after_normalization) { + for (int i = 0; i < weights_cols; ++i) { + scale_values[i] = + (1.0f / sqrtf(variance.flat()(i) + variance_epsilon)) * + gamma.flat()(i); + offset_values[i] = 0.0f; + } + } else { + for (int i = 0; i < weights_cols; ++i) { + scale_values[i] = + (1.0f / sqrtf(variance.flat()(i) + variance_epsilon)); + offset_values[i] = (-mean.flat()(i) * scale_values[i]) + + beta.flat()(i); + } + } + + // Multiply the original weights by the scale vector. + auto weights_matrix = weights.flat_inner_dims(); + Tensor scaled_weights(DT_FLOAT, weights.shape()); + auto scaled_weights_matrix = scaled_weights.flat_inner_dims(); + for (int64 row = 0; row < weights_matrix.dimension(0); ++row) { + for (int64 col = 0; col < weights_cols; ++col) { + scaled_weights_matrix(row, col) = + weights_matrix(row, col) * scale_values[col]; + } + } + // Figure out the remaining bias to add on. + Tensor bias_offset(DT_FLOAT, {weights_cols}); + auto bias_offset_vector = bias_offset.flat(); + for (int64 col = 0; col < weights_cols; ++col) { + bias_offset_vector(col) = offset_values[col]; + } + + // Construct the new nodes. + NodeDef scaled_weights_node; + scaled_weights_node.set_op("Const"); + scaled_weights_node.set_name(weights_node.name()); + SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node); + SetNodeTensorAttr("value", scaled_weights, + &scaled_weights_node); + new_nodes->push_back(scaled_weights_node); + + // The input and convolution can be copied straight over, since the + // name of the scaled weights constant is the same as the original. + new_nodes->push_back(input_node); + new_nodes->push_back(conv_node); + + NodeDef bias_offset_node; + bias_offset_node.set_op("Const"); + bias_offset_node.set_name(conv_node.name() + "_bn_offset"); + SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node); + SetNodeTensorAttr("value", bias_offset, &bias_offset_node); + new_nodes->push_back(bias_offset_node); + + NodeDef bias_add_node; + bias_add_node.set_op("BiasAdd"); + bias_add_node.set_name(batch_norm_node.name()); + CopyNodeAttr(conv_node, "T", "T", &bias_add_node); + AddNodeInput(conv_node.name(), &bias_add_node); + AddNodeInput(bias_offset_node.name(), &bias_add_node); + new_nodes->push_back(bias_add_node); + + did_graph_change = true; + + return Status::OK(); + }, + {}, &replaced_graph_def)); + current_graph_def = replaced_graph_def; + } while (did_graph_change); + *output_graph_def = current_graph_def; + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc new file mode 100644 index 00000000000..1c4958d83c9 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status FoldOldBatchNorms(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class FoldOldBatchNormsTest : public ::testing::Test { + protected: + void TestFoldOldBatchNorms() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + Tensor mean_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&mean_data, {10.0f, 20.0f}); + Output mean_op = + Const(root.WithOpName("mean_op"), Input::Initializer(mean_data)); + + Tensor variance_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&variance_data, {0.25f, 0.5f}); + Output variance_op = Const(root.WithOpName("variance_op"), + Input::Initializer(variance_data)); + + Tensor beta_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&beta_data, {0.1f, 0.6f}); + Output beta_op = + Const(root.WithOpName("beta_op"), Input::Initializer(beta_data)); + + Tensor gamma_data(DT_FLOAT, TensorShape({2})); + test::FillValues(&gamma_data, {1.0f, 2.0f}); + Output gamma_op = + Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data)); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + // This is needed because we're trying to convert over a deprecated op which + // should only be present in older GraphDef files. Without this we see a + // deprecation error. + // This is justified because we're trying to test a tool that is expected to + // run on legacy files, to help users convert over to less problematic + // versions. + NodeDef batch_norm_node; + batch_norm_node.set_op("BatchNormWithGlobalNormalization"); + batch_norm_node.set_name("output"); + AddNodeInput("conv_op", &batch_norm_node); + AddNodeInput("mean_op", &batch_norm_node); + AddNodeInput("variance_op", &batch_norm_node); + AddNodeInput("beta_op", &batch_norm_node); + AddNodeInput("gamma_op", &batch_norm_node); + SetNodeAttr("T", DT_FLOAT, &batch_norm_node); + SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node); + SetNodeAttr("scale_after_normalization", false, &batch_norm_node); + *(original_graph_def.mutable_node()->Add()) = batch_norm_node; + original_graph_def.mutable_versions()->set_producer(8); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("BatchNormWithGlobalNormalization", node.op()); + } + } +}; + +TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) { + TestFoldOldBatchNorms(); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fuse_convolutions.cc b/tensorflow/tools/graph_transforms/fuse_convolutions.cc new file mode 100644 index 00000000000..df6e9e6dc28 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fuse_convolutions.cc @@ -0,0 +1,200 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +Status FuseResizePadAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"Conv2D", + { + {"MirrorPad", + { + {"ResizeBilinear"}, + {"*"} + } + }, + {"*"} + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + // Find all the nodes we expect in the subgraph. + const NodeDef& conv_node = match.node; + const NodeDef& mirror_pad_node = match.inputs[0].node; + const NodeDef& weights_node = match.inputs[1].node; + const NodeDef& resize_node = match.inputs[0].inputs[0].node; + const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node; + + // We'll be reusing the old weights and pad dimensions. + new_nodes->push_back(weights_node); + new_nodes->push_back(pad_dims_node); + + // Set up the new fused version of the convolution op. + NodeDef fused_conv; + fused_conv.set_op("FusedResizeAndPadConv2D"); + fused_conv.set_name(match.node.name()); + AddNodeInput(resize_node.input(0), &fused_conv); + AddNodeInput(resize_node.input(1), &fused_conv); + AddNodeInput(mirror_pad_node.input(1), &fused_conv); + AddNodeInput(conv_node.input(1), &fused_conv); + CopyNodeAttr(resize_node, "align_corners", "resize_align_corners", + &fused_conv); + CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv); + CopyNodeAttr(conv_node, "T", "T", &fused_conv); + CopyNodeAttr(conv_node, "padding", "padding", &fused_conv); + CopyNodeAttr(conv_node, "strides", "strides", &fused_conv); + new_nodes->push_back(fused_conv); + + return Status::OK(); + }, + {}, &replaced_graph_def)); + *output_graph_def = replaced_graph_def; + return Status::OK(); +} + +Status FuseResizeAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"Conv2D", + { + {"ResizeBilinear"}, + {"*"} + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + // Find all the nodes we expect in the subgraph. + const NodeDef& conv_node = match.node; + const NodeDef& resize_node = match.inputs[0].node; + const NodeDef& weights_node = match.inputs[1].node; + + // We'll be reusing the old weights. + new_nodes->push_back(weights_node); + + // Create a 'no-op' mirror padding node that has no effect. + NodeDef pad_dims_node; + pad_dims_node.set_op("Const"); + pad_dims_node.set_name(conv_node.name() + "_dummy_paddings"); + SetNodeAttr("dtype", DT_INT32, &pad_dims_node); + SetNodeTensorAttr("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0}, + &pad_dims_node); + new_nodes->push_back(pad_dims_node); + + // Set up the new fused version of the convolution op. + NodeDef fused_conv; + fused_conv.set_op("FusedResizeAndPadConv2D"); + fused_conv.set_name(match.node.name()); + AddNodeInput(resize_node.input(0), &fused_conv); + AddNodeInput(resize_node.input(1), &fused_conv); + AddNodeInput(pad_dims_node.name(), &fused_conv); + AddNodeInput(conv_node.input(1), &fused_conv); + CopyNodeAttr(resize_node, "align_corners", "resize_align_corners", + &fused_conv); + SetNodeAttr("mode", "REFLECT", &fused_conv); + CopyNodeAttr(conv_node, "T", "T", &fused_conv); + CopyNodeAttr(conv_node, "padding", "padding", &fused_conv); + CopyNodeAttr(conv_node, "strides", "strides", &fused_conv); + new_nodes->push_back(fused_conv); + + return Status::OK(); + }, + {}, &replaced_graph_def)); + *output_graph_def = replaced_graph_def; + return Status::OK(); +} + +Status FusePadAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"Conv2D", + { + {"MirrorPad", + { + {"*"}, + {"*"}, + } + }, + {"*"} + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + // Find all the nodes we expect in the subgraph. + const NodeDef& conv_node = match.node; + CHECK_EQ("Conv2D", conv_node.op()); + const NodeDef& mirror_pad_node = match.inputs[0].node; + CHECK_EQ("MirrorPad", mirror_pad_node.op()); + const NodeDef& weights_node = match.inputs[1].node; + const NodeDef& input_node = match.inputs[0].inputs[0].node; + const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node; + + // We'll be reusing the old weights and pad dimensions. + new_nodes->push_back(weights_node); + new_nodes->push_back(input_node); + new_nodes->push_back(pad_dims_node); + + // Set up the new fused version of the convolution op. + NodeDef fused_conv; + fused_conv.set_op("FusedPadConv2D"); + fused_conv.set_name(match.node.name()); + AddNodeInput(mirror_pad_node.input(0), &fused_conv); + AddNodeInput(mirror_pad_node.input(1), &fused_conv); + AddNodeInput(conv_node.input(1), &fused_conv); + CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv); + CopyNodeAttr(conv_node, "T", "T", &fused_conv); + CopyNodeAttr(conv_node, "padding", "padding", &fused_conv); + CopyNodeAttr(conv_node, "strides", "strides", &fused_conv); + new_nodes->push_back(fused_conv); + + return Status::OK(); + }, + {}, &replaced_graph_def)); + *output_graph_def = replaced_graph_def; + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("fuse_resize_pad_and_conv", FuseResizePadAndConv); + +REGISTER_GRAPH_TRANSFORM("fuse_resize_and_conv", FuseResizeAndConv); + +REGISTER_GRAPH_TRANSFORM("fuse_pad_and_conv", FusePadAndConv); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc b/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc new file mode 100644 index 00000000000..b315b9caba1 --- /dev/null +++ b/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc @@ -0,0 +1,212 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status FuseResizePadAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status FuseResizeAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status FusePadAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class FuseConvolutionsTest : public ::testing::Test { + protected: + void TestFuseResizePadAndConv() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op, + Const(root.WithOpName("size"), {12, 4}), + ResizeBilinear::AlignCorners(false)); + + Tensor pad_dims_data(DT_INT32, TensorShape({4, 2})); + test::FillValues(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0}); + Output pad_dims_op = Const(root.WithOpName("pad_dims_op"), + Input::Initializer(pad_dims_data)); + Output pad_op = + MirrorPad(root.WithOpName("pad_op"), resize_op, pad_dims_op, "REFLECT"); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FuseResizePadAndConv(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("Conv2D", node.op()); + EXPECT_NE("MirrorPad", node.op()); + EXPECT_NE("ResizeBilinear", node.op()); + } + } + + void TestFuseResizeAndConv() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op, + Const(root.WithOpName("size"), {12, 4}), + ResizeBilinear::AlignCorners(false)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("output"), resize_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK(FuseResizeAndConv(original_graph_def, {{}, {"output"}}, + &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("Conv2D", node.op()); + EXPECT_NE("ResizeBilinear", node.op()); + } + } + + void TestFusePadAndConv() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor pad_dims_data(DT_INT32, TensorShape({4, 2})); + test::FillValues(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0}); + Output pad_dims_op = Const(root.WithOpName("pad_dims_op"), + Input::Initializer(pad_dims_data)); + Output pad_op = + MirrorPad(root.WithOpName("pad_op"), input_op, pad_dims_op, "REFLECT"); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2})); + test::FillValues(&weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef fused_graph_def; + TF_ASSERT_OK( + FusePadAndConv(original_graph_def, {{}, {"output"}}, &fused_graph_def)); + + std::unique_ptr fused_session(NewSession(SessionOptions())); + TF_ASSERT_OK(fused_session->Create(fused_graph_def)); + std::vector fused_outputs; + TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs)); + + test::ExpectTensorNear(original_outputs[0], fused_outputs[0], 1e-5); + + for (const NodeDef& node : fused_graph_def.node()) { + EXPECT_NE("Conv2D", node.op()); + EXPECT_NE("MirrorPad", node.op()); + } + } +}; + +TEST_F(FuseConvolutionsTest, TestFuseResizePadAndConv) { + TestFuseResizePadAndConv(); +} + +TEST_F(FuseConvolutionsTest, TestFuseResizeAndConv) { TestFuseResizeAndConv(); } + +TEST_F(FuseConvolutionsTest, TestFusePadAndConv) { TestFusePadAndConv(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/obsfucate_names.cc b/tensorflow/tools/graph_transforms/obsfucate_names.cc new file mode 100644 index 00000000000..00eb0d01b02 --- /dev/null +++ b/tensorflow/tools/graph_transforms/obsfucate_names.cc @@ -0,0 +1,104 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Renames all nodes not uses as graph inputs or outputs to short numerical +// forms. +Status ObsfucateNames(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + std::unordered_set required_nodes; + for (const string& input : context.input_names) { + required_nodes.insert(input); + } + for (const string& output : context.output_names) { + required_nodes.insert(output); + } + + for (const string& required_node : required_nodes) { + LOG(INFO) << "required_node=" << required_node; + } + + const string valid_chars = + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + const int64 chars_size = valid_chars.size(); + + std::map new_names; + int64 name_index = 0; + for (const NodeDef& input_node : input_graph_def.node()) { + const string& old_name = input_node.name(); + string new_name; + if (required_nodes.count(old_name)) { + new_name = old_name; + } else { + do { + int64 remaining = name_index; + new_name = ""; + while (true) { + const int64 remainder = (remaining % chars_size); + const char current_char = valid_chars[remainder]; + new_name = current_char + new_name; + remaining /= chars_size; + if (remaining <= 0) { + break; + } + } + ++name_index; + } while (required_nodes.count(new_name)); + } + new_names[old_name] = new_name; + } + + output_graph_def->Clear(); + for (const NodeDef& input_node : input_graph_def.node()) { + NodeDef* node = output_graph_def->mutable_node()->Add(); + node->CopyFrom(input_node); + const string& old_name = input_node.name(); + node->set_name(new_names[old_name]); + node->mutable_input()->Clear(); + for (const string& input_name : input_node.input()) { + string prefix; + string input_node_name; + string suffix; + NodeNamePartsFromInput(input_name, &prefix, &input_node_name, &suffix); + if (new_names.count(input_node_name) == 0) { + return errors::InvalidArgument("No node named ", input_node_name, + " for input to ", old_name); + } + string new_input_name = prefix + new_names[input_node_name] + suffix; + *(node->mutable_input()->Add()) = new_input_name; + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/obsfucate_names_test.cc b/tensorflow/tools/graph_transforms/obsfucate_names_test.cc new file mode 100644 index 00000000000..90b34a707ab --- /dev/null +++ b/tensorflow/tools/graph_transforms/obsfucate_names_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status ObsfucateNames(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class ObsfucateNamesTest : public ::testing::Test { + protected: + void TestSimpleTree() { + GraphDef graph_def; + + NodeDef* add_node1 = graph_def.add_node(); + add_node1->set_name("add_node1"); + add_node1->set_op("Add"); + add_node1->add_input("add_node2"); + add_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node3"); + add_node3->add_input("const_node4"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* const_node4 = graph_def.add_node(); + const_node4->set_name("const_node4"); + const_node4->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK( + ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + + EXPECT_EQ(1, node_lookup.count("add_node1")); + EXPECT_EQ(0, node_lookup.count("add_node2")); + EXPECT_EQ(0, node_lookup.count("add_node3")); + EXPECT_EQ(1, node_lookup.count("const_node1")); + EXPECT_EQ(0, node_lookup.count("const_node2")); + EXPECT_EQ(0, node_lookup.count("const_node3")); + EXPECT_EQ(0, node_lookup.count("const_node4")); + } + + void TestManyNodes() { + GraphDef graph_def; + for (int i = 0; i < 1000; ++i) { + NodeDef* const_node = graph_def.add_node(); + const_node->set_name(strings::StrCat("const_node", i)); + const_node->set_op("Const"); + } + + GraphDef result; + TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}}, + &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("const_node0")); + EXPECT_EQ(0, node_lookup.count("const_node500")); + EXPECT_EQ(1, node_lookup.count("const_node999")); + } + + void TestNameClashes() { + GraphDef graph_def; + for (int i = 0; i < 1000; ++i) { + NodeDef* const_node = graph_def.add_node(); + const_node->set_name(strings::StrCat("1", i)); + const_node->set_op("Const"); + } + + GraphDef result; + TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("10")); + EXPECT_EQ(1, node_lookup.count("19")); + + std::unordered_set names; + for (const NodeDef& node : result.node()) { + EXPECT_EQ(0, names.count(node.name())) + << "Found multiple nodes with name '" << node.name() << "'"; + names.insert(node.name()); + } + } +}; + +TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); } + +TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); } + +TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc new file mode 100644 index 00000000000..8b0393049ac --- /dev/null +++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -0,0 +1,922 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Holds the information we need to translate from a float version of this op +// into the quantized equivalent. +struct QuantizedOpInfo { + // The name of the float op. + string float_name; + // Which attributes to copy directly over. + std::vector attrs_to_copy; + // Extra data type attributes we need to set. + std::vector> dtypes_to_set; + // What depth of inputs the op can read in. + DataType input_bit_depth; + // The depth of the op's quantized outputs. + DataType output_bit_depth; + // Which inputs (e.g. shapes) aren't involved in the quantization process. + std::set unquantized_inputs; + // How the outputs are arranged, either + // [input0, input1, min0, max0, min1, max1] for contiguous, or + // [input0, input1, min0, min1, max0, max1] for separate. + // The separate order is needed because it's the only way to specify unknown + // numbers of inputs for ops like Concat. + enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order; +}; + +// Every op that has a quantized equivalent should be listed here, so that the +// conversion process can transform them. +const std::vector& GetQuantizedOpList() { + static const std::vector op_list = { + {"AvgPool", + {"ksize", "strides", "padding"}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"BiasAdd", + {}, + {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}}, + DT_QUINT8, + DT_QINT32, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"Concat", + {"N"}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {0}, + QuantizedOpInfo::SEPARATE_MIN_MAX}, + {"Conv2D", + {"strides", "padding"}, + {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}}, + DT_QUINT8, + DT_QINT32, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"MatMul", + {"transpose_a", "transpose_b"}, + {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}}, + DT_QUINT8, + DT_QINT32, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"MaxPool", + {"ksize", "strides", "padding"}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"Relu", + {}, + {{"Tinput", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"Relu6", + {}, + {{"Tinput", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"Reshape", + {}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {1}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + }; + return op_list; +} + +namespace { +// Replaces invalid characters in input names to get a unique node name. +string UniqueNodeNameFromInput(const string& input_name) { + string prefix; + string node_name; + string suffix; + NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); + string result; + if (prefix == "^") { + result += "__hat__"; + } + result += node_name; + if (suffix != "") { + result += "__port__" + suffix.substr(1, suffix.size() - 1); + } + return result; +} + +// Pulls two float values from the named parameters, with a lot of checking. +Status ExtractRangeFromParams(const TransformFuncContext& context, + const string& min_name, const string& max_name, + float* min_value, float* max_value, + bool* has_range) { + // See if we've been given quantized inputs with a known range. + const bool has_min = (context.params.count(min_name) != 0); + const bool has_max = (context.params.count(max_name) != 0); + *has_range = (has_min || has_max); + if (!*has_range) { + return Status::OK(); + } + if (!has_min || !has_max) { + return errors::InvalidArgument("You must pass both ", min_name, " and ", + max_name, " into quantize_nodes"); + } + std::vector min_strings = context.params.at(min_name); + std::vector max_strings = context.params.at(max_name); + if ((min_strings.size() != 1) || (max_strings.size() != 1)) { + return errors::InvalidArgument("You must pass a single ", min_name, + " and single ", max_name, + " value into " + "quantize_nodes"); + } + if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) { + return errors::InvalidArgument("Couldn't decode ", min_name, + " as a number: ", min_strings[0]); + } + if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) { + return errors::InvalidArgument("Couldn't decode ", max_name, + " as a number: ", max_strings[0]); + } + return Status::OK(); +} + +bool AreAttrsEqual(const NodeDef* current_node, const NodeDef* other_node) { + if (current_node->attr_size() != other_node->attr_size()) { + return false; + } + string current_serialized; + string other_serialized; + for (const auto& attr : other_node->attr()) { + auto iter = current_node->attr().find(attr.first); + if (iter == current_node->attr().end()) return false; + iter->second.SerializeToString(¤t_serialized); + attr.second.SerializeToString(&other_serialized); + if (current_serialized != other_serialized) return false; + } + return true; +} + +} // namespace + +// Analyzes all the nodes in the graph to figure out which ones are duplicates +// apart from their names. This commonly includes identical Const nodes, but can +// also be simple operations that are repeated on multiple outputs of a +// particular node. The complexity is managed using a hash function that avoids +// the need for any O(n^2) algorithms when identifying duplicates. +Status MergeDuplicateNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + // Make sure we can look up inputs and outputs quickly. + std::set input_names(context.input_names.begin(), + context.input_names.end()); + std::set output_names(context.output_names.begin(), + context.output_names.end()); + GraphDef current_graph_def = input_graph_def; + // Keep running the merging until no more duplicates are found. + bool any_duplicates_found; + do { + any_duplicates_found = false; + // First arrange all of the nodes by a hash of their contents. + std::map> hashed_nodes; + for (const NodeDef& node : current_graph_def.node()) { + NodeDef nameless_node = node; + // The name matters if it's being used as an input or output node, + // otherwise ignore it when looking for duplicates. + if (!input_names.count(node.name()) && !output_names.count(node.name())) { + nameless_node.set_name(""); + } + const uint64 hash = HashNodeDef(nameless_node); + hashed_nodes[hash].push_back(&node); + } + // If we have multiple nodes with the same hash, then we know they're + // duplicates and can be removed, unless they're stateful. + std::map inputs_to_rename; + GraphDef merged_graph_def; + for (const std::pair> hashed_node_info : + hashed_nodes) { + const std::vector& hash_node_list = + hashed_node_info.second; + for (int i = 0; i < hash_node_list.size(); ++i) { + const NodeDef* current_node = hash_node_list[i]; + const OpDef* op_def = nullptr; + TF_RETURN_IF_ERROR( + OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def)); + const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0)); + if (is_duplicate) { + const string original_name = hash_node_list[0]->name(); + inputs_to_rename[current_node->name() + ":*"] = original_name; + any_duplicates_found = true; + } else { + NodeDef* new_node = merged_graph_def.mutable_node()->Add(); + *new_node = *current_node; + } + } + } + // Update the graph so that any nodes that referred to removed inputs now + // pull from the remaining duplicate. + RenameNodeInputs(merged_graph_def, inputs_to_rename, ¤t_graph_def); + } while (any_duplicates_found); + + *output_graph_def = current_graph_def; + + return Status::OK(); +} + +// Looks for the patterns that indicate there are two eight-bit ops feeding into +// each other, separated by a conversion up to float and back again. These occur +// during the initial conversion of ops to their quantized forms. Because we're +// only looking at an individual op in that phase and don't know if its inputs +// and outputs are eight-bit-capable, we start by converting the actual op into +// quantized form, but add float conversions before and after. This pass gets +// rid of those conversions if it turns out we do have adjacent ops capable of +// eight-bit processing. +Status RemoveRedundantQuantizations(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + std::set graph_outputs; + for (const string& output_name : context.output_names) { + graph_outputs.insert(NodeNameFromInput(output_name)); + } + std::map inputs_to_rename; + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"QuantizeV2", + { + {"Dequantize"}, + {"Min"}, + {"Max"}, + } + }, // clang-format on + [&inputs_to_rename, &graph_outputs](const NodeMatch& match, + const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& quantize_node = match.node; + const NodeDef& dequantize_node = match.inputs[0].node; + inputs_to_rename[quantize_node.name() + ":0"] = + dequantize_node.input(0); + inputs_to_rename[quantize_node.name() + ":1"] = + dequantize_node.input(1); + inputs_to_rename[quantize_node.name() + ":2"] = + dequantize_node.input(2); + + // Are other sub-graphs using the float intermediate result? If so, + // preserve it, but the input renaming still rewires the eight-bit ops + // so they don't go through float. + if (output_nodes.count(dequantize_node.name()) || + graph_outputs.count(dequantize_node.name())) { + CopyOriginalMatch(match, new_nodes); + } + + return Status::OK(); + }, + {true}, &replaced_graph_def)); + + RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def); + + return Status::OK(); +} + +// If the user has passed in the input_min and input_max args, then we need to +// convert any input placeholders from float to eight bit, so quantized inputs +// can be fed directly into the graph. +Status QuantizePlaceholders(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + float input_min; + float input_max; + bool has_input_range; + TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max", + &input_min, &input_max, + &has_input_range)); + if (!has_input_range) { + *output_graph_def = input_graph_def; + return Status::OK(); + } + std::map inputs_to_rename_first_pass; + std::map inputs_to_rename_second_pass; + GraphDef placeholder_graph_def; + placeholder_graph_def.Clear(); + for (const NodeDef& node : input_graph_def.node()) { + if (node.op() != "Placeholder") { + (placeholder_graph_def.mutable_node()->Add())->CopyFrom(node); + } else { + string namespace_prefix = node.name() + "_eightbit"; + + NodeDef quantized_placeholder; + quantized_placeholder.CopyFrom(node); + SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder); + (placeholder_graph_def.mutable_node()->Add()) + ->CopyFrom(quantized_placeholder); + + NodeDef min_node; + min_node.set_op("Const"); + min_node.set_name(namespace_prefix + "/min"); + SetNodeAttr("dtype", DT_FLOAT, &min_node); + Tensor min_tensor(DT_FLOAT, {}); + min_tensor.flat()(0) = input_min; + SetNodeTensorAttr("value", min_tensor, &min_node); + (placeholder_graph_def.mutable_node()->Add())->CopyFrom(min_node); + + NodeDef max_node; + max_node.set_op("Const"); + max_node.set_name(namespace_prefix + "/max"); + SetNodeAttr("dtype", DT_FLOAT, &max_node); + Tensor max_tensor(DT_FLOAT, {}); + max_tensor.flat()(0) = input_max; + SetNodeTensorAttr("value", max_tensor, &max_node); + (placeholder_graph_def.mutable_node()->Add())->CopyFrom(max_node); + + const string rename_suffix = "__RENAMED_PLACEHOLDER__"; + NodeDef dequantize_node; + dequantize_node.set_op("Dequantize"); + dequantize_node.set_name(namespace_prefix + "/dequantize"); + SetNodeAttr("T", DT_QUINT8, &dequantize_node); + SetNodeAttr("mode", "MIN_FIRST", &dequantize_node); + AddNodeInput(node.name() + rename_suffix, &dequantize_node); + AddNodeInput(min_node.name(), &dequantize_node); + AddNodeInput(max_node.name(), &dequantize_node); + (placeholder_graph_def.mutable_node()->Add())->CopyFrom(dequantize_node); + + // First make sure that any internal references to the old placeholder + // now point to the dequantize result. + inputs_to_rename_first_pass[node.name()] = dequantize_node.name(); + // Then fix up the dequantize op so that it really points to the + // placeholder. + inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name(); + } + } + + GraphDef first_pass_graph_def; + RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass, + &first_pass_graph_def); + RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass, + output_graph_def); + + return Status::OK(); +} + +// During training, FakeQuantWithMinMaxVars ops capture a good min/max range for +// an activation layer. To use these during inference, this pass converts those +// ops into Requantizes with the trained min/maxes as constant inputs. +Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"FakeQuantWithMinMaxVars", + { + {"*"}, + {"Const"}, + {"Const"}, + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& fake_quant_node = match.node; + const NodeDef& original_op_node = match.inputs[0].node; + const NodeDef& fake_quant_min_node = match.inputs[1].node; + const NodeDef& fake_quant_max_node = match.inputs[2].node; + + string namespace_prefix = fake_quant_node.name() + "_eightbit"; + + new_nodes->push_back(original_op_node); + new_nodes->push_back(fake_quant_min_node); + new_nodes->push_back(fake_quant_max_node); + + NodeDef quantize_node; + quantize_node.set_op("QuantizeV2"); + quantize_node.set_name(namespace_prefix + "/quantize"); + SetNodeAttr("T", DT_QINT32, &quantize_node); + SetNodeAttr("mode", "MIN_FIRST", &quantize_node); + AddNodeInput(fake_quant_node.input(0), &quantize_node); + AddNodeInput(fake_quant_min_node.name(), &quantize_node); + AddNodeInput(fake_quant_max_node.name(), &quantize_node); + new_nodes->push_back(quantize_node); + + NodeDef requantize_node; + requantize_node.set_op("Requantize"); + requantize_node.set_name(namespace_prefix + "/requantize"); + SetNodeAttr("Tinput", DT_QINT32, &requantize_node); + SetNodeAttr("out_type", DT_QUINT8, &requantize_node); + AddNodeInput(quantize_node.name() + ":0", &requantize_node); + AddNodeInput(quantize_node.name() + ":1", &requantize_node); + AddNodeInput(quantize_node.name() + ":2", &requantize_node); + AddNodeInput(fake_quant_min_node.name(), &requantize_node); + AddNodeInput(fake_quant_max_node.name(), &requantize_node); + new_nodes->push_back(requantize_node); + + // Convert the 8-bit result back into float for the final output. + NodeDef dequantize_node; + dequantize_node.set_op("Dequantize"); + dequantize_node.set_name(fake_quant_node.name()); + SetNodeAttr("T", DT_QUINT8, &dequantize_node); + SetNodeAttr("mode", "MIN_FIRST", &dequantize_node); + AddNodeInput(requantize_node.name() + ":0", &dequantize_node); + AddNodeInput(requantize_node.name() + ":1", &dequantize_node); + AddNodeInput(requantize_node.name() + ":2", &dequantize_node); + new_nodes->push_back(dequantize_node); + + return Status::OK(); + }, + {}, output_graph_def)); + + return Status::OK(); +} + +// We always generate Requantize ops driven by dynamic RequantizationRange +// calculations when we produce quantized ops like Conv2D or BiasAdd with +// 32-bit results. If there were FakeQuant ops already for those activation +// layers, then there will be a later Requantize op with constant min/max +// inputs, which is preferable for fast inference. This pass looks for those +// later Requantize ops, and replaces the dynamic version with them. +Status MergeAdjacentRequantizes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, // clang-format off + {"Requantize", + { + {"QuantizeV2", + { + {"Dequantize", + { + {"Requantize", + { + {"*"}, + {"*"}, + {"*"}, + {"RequantizationRange"}, + {"RequantizationRange"}, + } + }, + {"Requantize"}, + {"Requantize"}, + } + }, + {"Const"}, + {"Const"}, + }, + }, + {"QuantizeV2"}, + {"QuantizeV2"}, + {"Const"}, + {"Const"}, + } + }, // clang-format on + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& fake_requantize_node = match.node; + const NodeDef& original_op_node = + match.inputs[0].inputs[0].inputs[0].inputs[0].node; + const NodeDef& fake_requantize_min_node = match.inputs[3].node; + const NodeDef& fake_requantize_max_node = match.inputs[4].node; + + new_nodes->push_back(original_op_node); + new_nodes->push_back(fake_requantize_min_node); + new_nodes->push_back(fake_requantize_max_node); + + NodeDef requantize_node; + requantize_node.CopyFrom(fake_requantize_node); + requantize_node.mutable_input()->Clear(); + AddNodeInput(original_op_node.name() + ":0", &requantize_node); + AddNodeInput(original_op_node.name() + ":1", &requantize_node); + AddNodeInput(original_op_node.name() + ":2", &requantize_node); + AddNodeInput(fake_requantize_min_node.name(), &requantize_node); + AddNodeInput(fake_requantize_max_node.name(), &requantize_node); + new_nodes->push_back(requantize_node); + + return Status::OK(); + }, + {}, output_graph_def)); + + return Status::OK(); +} + +// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of +// linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd +// op that we want to apply the trained constant conversions to. This pass tries +// to move FakeQuant ops up the input chain, so they're as close as possible to +// the 32-bit conversion, and so can be easily merged into the automatic dynamic +// Requantizes. +Status HoistFakeQuants(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef current_graph_def = input_graph_def; + const int max_depth = 3; + for (int depth = max_depth; depth > 0; --depth) { + OpTypePattern pattern = {"*"}; + for (int i = 0; i < depth; ++i) { + pattern = {"*", {pattern}}; + } + pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}}; + GraphDef hoisted_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, pattern, + [depth](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& fake_quant_node = match.node; + const NodeDef& fake_quant_min_node = match.inputs[1].node; + const NodeDef& fake_quant_max_node = match.inputs[2].node; + std::vector linear_nodes; + NodeMatch current_match = match; + for (int i = 0; i <= depth; ++i) { + linear_nodes.push_back(current_match.inputs[0].node); + current_match = current_match.inputs[0]; + } + NodeDef new_fake_quant_node; + new_fake_quant_node.CopyFrom(fake_quant_node); + new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted"); + new_fake_quant_node.set_input( + 0, linear_nodes[linear_nodes.size() - 2].input(0)); + new_nodes->push_back(new_fake_quant_node); + + new_nodes->push_back(fake_quant_min_node); + new_nodes->push_back(fake_quant_max_node); + + linear_nodes[linear_nodes.size() - 2].set_input( + 0, new_fake_quant_node.name()); + linear_nodes.front().set_name(fake_quant_node.name()); + for (const NodeDef& linear_node : linear_nodes) { + new_nodes->push_back(linear_node); + } + + return Status::OK(); + }, + {}, &hoisted_graph_def)); + current_graph_def = hoisted_graph_def; + } + *output_graph_def = current_graph_def; + + return Status::OK(); +} + +// Converts any float ops that have eight-bit equivalents into their quantized +// forms, so that as much calculation as possible is done in the lower-precision +// format. +Status QuantizeNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + // Loop through all of the quantizable op types, and replace any occurrences + // with equivalent sub-graphs with quantized ops at their core. For example + // this one-input operation: + // + // Input(float) + // | + // v + // Operation + // | + // v + // (float) + // + // Will be turned into it's quantized equivalent: + // + // Input(float) ReshapeDims + // +------v v-------------+ + // | Reshape + // | | + // | | ReductionDims + // | +-----+ | + // | | +---c---------+ + // | v v v v-------+ + // | Min Max + // | +----+ | + // v v v--------+ + // Quantize + // | + // v + // QuantizedOperation + // | | | + // v v v + // Dequantize + // | + // v + // (float) + // + // This keeps the inputs and outputs visible to the rest of the graph in + // float + // and converts them down to quantized buffers internally for the + // computation. + // The result will end up with a lot of redundant dequantize/quantize pairs + // between adjacent quantized ops, but a later pass removes these where it + // can. + const std::vector& op_list = GetQuantizedOpList(); + string op_pattern; + bool is_first = true; + std::map op_map; + for (const QuantizedOpInfo& op_info : op_list) { + strings::StrAppend(&op_pattern, (is_first ? "" : "|"), op_info.float_name); + op_map.insert({op_info.float_name, op_info}); + is_first = false; + } + + // If input_min and input max have been passed in, then we convert all float + // Placeholder nodes into quantized versions, with the supplied values as + // their range. + GraphDef placeholder_graph_def; + TF_RETURN_IF_ERROR( + QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def)); + + // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear + // operations like Relu or MaxPool, move them up so that they're as close as + // possible to ops with 32-bit outputs like BiasAdd or Conv2D. + GraphDef hoisted_graph_def; + TF_RETURN_IF_ERROR( + HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def)); + + // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of + // activation layers, into Requantize ops with those ranges instead. This + // makes it easier to replace the dynamic range calculations that are used + // by default. + GraphDef converted_graph_def; + TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context, + &converted_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def)); + + // If fallback_min and fallback_max are set, then we'll use hardwired ranges + // for all the 32-bit to 8-bit requantizations. + float fallback_min; + float fallback_max; + bool has_fallback_range; + TF_RETURN_IF_ERROR(ExtractRangeFromParams( + context, "fallback_min", "fallback_max", &fallback_min, &fallback_max, + &has_fallback_range)); + + // Replace all occurrences of the current float op with its quantized + // equivalent. + GraphDef quantized_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + converted_graph_def, {op_pattern}, + [&op_map, fallback_min, fallback_max, has_fallback_range]( + const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& float_node = match.node; + const QuantizedOpInfo& op_info = op_map[float_node.op()]; + + string namespace_prefix = float_node.name() + "_eightbit"; + + // Quantize all of the inputs. + std::vector quantized_input_names; + for (int i = 0; i < float_node.input_size(); ++i) { + // Skip any non-float inputs. + if (op_info.unquantized_inputs.count(i)) { + continue; + } + + const string& input_name = float_node.input(i); + string unique_input_name = + namespace_prefix + "/" + UniqueNodeNameFromInput(input_name); + + // Add some common constants we need for reshaping inputs. + NodeDef reshape_dims; + reshape_dims.set_op("Const"); + reshape_dims.set_name(unique_input_name + "/reshape_dims"); + SetNodeAttr("dtype", DT_INT32, &reshape_dims); + Tensor reshape_dims_tensor(DT_INT32, {1}); + reshape_dims_tensor.flat()(0) = -1; + SetNodeTensorAttr("value", reshape_dims_tensor, &reshape_dims); + new_nodes->push_back(reshape_dims); + + NodeDef reduction_dims; + reduction_dims.set_op("Const"); + reduction_dims.set_name(unique_input_name + "/reduction_dims"); + SetNodeAttr("dtype", DT_INT32, &reduction_dims); + Tensor reduction_dims_tensor(DT_INT32, {1}); + reduction_dims_tensor.flat()(0) = 0; + SetNodeTensorAttr("value", reduction_dims_tensor, + &reduction_dims); + new_nodes->push_back(reduction_dims); + + NodeDef reshape_node; + reshape_node.set_op("Reshape"); + reshape_node.set_name(unique_input_name + "/reshape"); + SetNodeAttr("T", DT_FLOAT, &reshape_node); + AddNodeInput(input_name, &reshape_node); + AddNodeInput(reshape_dims.name(), &reshape_node); + new_nodes->push_back(reshape_node); + + NodeDef min_node; + min_node.set_op("Min"); + min_node.set_name(unique_input_name + "/min"); + SetNodeAttr("T", DT_FLOAT, &min_node); + SetNodeAttr("keep_dims", false, &min_node); + AddNodeInput(reshape_node.name(), &min_node); + AddNodeInput(reduction_dims.name(), &min_node); + new_nodes->push_back(min_node); + + NodeDef max_node; + max_node.set_op("Max"); + max_node.set_name(unique_input_name + "/max"); + SetNodeAttr("T", DT_FLOAT, &max_node); + SetNodeAttr("keep_dims", false, &max_node); + AddNodeInput(reshape_node.name(), &max_node); + AddNodeInput(reduction_dims.name(), &max_node); + new_nodes->push_back(max_node); + + NodeDef quantize_node; + quantize_node.set_op("QuantizeV2"); + quantize_node.set_name(unique_input_name + "/quantize"); + SetNodeAttr("T", DT_QUINT8, &quantize_node); + SetNodeAttr("mode", "MIN_FIRST", &quantize_node); + AddNodeInput(input_name, &quantize_node); + AddNodeInput(min_node.name(), &quantize_node); + AddNodeInput(max_node.name(), &quantize_node); + new_nodes->push_back(quantize_node); + quantized_input_names.push_back(quantize_node.name()); + } + + // Set up the quantized version of the current op. + NodeDef quantized_main_node; + quantized_main_node.set_op("Quantized" + float_node.op()); + quantized_main_node.set_name(float_node.name() + "/eightbit"); + for (const string& attr_to_copy : op_info.attrs_to_copy) { + CopyNodeAttr(float_node, attr_to_copy, attr_to_copy, + &quantized_main_node); + } + for (const std::pair& dtype_to_set : + op_info.dtypes_to_set) { + SetNodeAttr(dtype_to_set.first, dtype_to_set.second, + &quantized_main_node); + } + int quantized_input_index = 0; + for (int i = 0; i < float_node.input_size(); ++i) { + if (op_info.unquantized_inputs.count(i)) { + AddNodeInput(float_node.input(i), &quantized_main_node); + } else { + const string& quantized_input_name = + quantized_input_names[quantized_input_index]; + AddNodeInput(quantized_input_name + ":0", &quantized_main_node); + ++quantized_input_index; + } + } + if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) { + for (const string& quantized_input_name : quantized_input_names) { + AddNodeInput(quantized_input_name + ":1", &quantized_main_node); + AddNodeInput(quantized_input_name + ":2", &quantized_main_node); + } + } else { + for (const string& quantized_input_name : quantized_input_names) { + AddNodeInput(quantized_input_name + ":1", &quantized_main_node); + } + for (const string& quantized_input_name : quantized_input_names) { + AddNodeInput(quantized_input_name + ":2", &quantized_main_node); + } + } + new_nodes->push_back(quantized_main_node); + + string eight_bit_node_name; + if (op_info.output_bit_depth == DT_QINT32) { + // Shrink the range of the output down from 32 bits to 8. + string requantize_min_input; + string requantize_max_input; + if (has_fallback_range) { + // Use constant values for the min/max range if they were given. + NodeDef fallback_min_node; + fallback_min_node.set_op("Const"); + fallback_min_node.set_name(quantized_main_node.name() + + "/fallback_min"); + SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node); + Tensor fallback_min_tensor(DT_FLOAT, {}); + fallback_min_tensor.flat()(0) = fallback_min; + SetNodeTensorAttr("value", fallback_min_tensor, + &fallback_min_node); + new_nodes->push_back(fallback_min_node); + + NodeDef fallback_max_node; + fallback_max_node.set_op("Const"); + fallback_max_node.set_name(quantized_main_node.name() + + "/fallback_max"); + SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node); + Tensor fallback_max_tensor(DT_FLOAT, {}); + fallback_max_tensor.flat()(0) = fallback_max; + SetNodeTensorAttr("value", fallback_max_tensor, + &fallback_max_node); + new_nodes->push_back(fallback_max_node); + + requantize_min_input = fallback_min_node.name(); + requantize_max_input = fallback_max_node.name(); + } else { + // Otherwise dynamically measure the range each time. + NodeDef requant_range_node; + requant_range_node.set_op("RequantizationRange"); + requant_range_node.set_name(quantized_main_node.name() + + "/requant_range"); + SetNodeAttr("Tinput", DT_QINT32, &requant_range_node); + AddNodeInput(quantized_main_node.name() + ":0", + &requant_range_node); + AddNodeInput(quantized_main_node.name() + ":1", + &requant_range_node); + AddNodeInput(quantized_main_node.name() + ":2", + &requant_range_node); + new_nodes->push_back(requant_range_node); + + requantize_min_input = requant_range_node.name() + ":0"; + requantize_max_input = requant_range_node.name() + ":1"; + } + NodeDef requantize_node; + requantize_node.set_op("Requantize"); + requantize_node.set_name(quantized_main_node.name() + "/requantize"); + SetNodeAttr("Tinput", DT_QINT32, &requantize_node); + SetNodeAttr("out_type", DT_QUINT8, &requantize_node); + AddNodeInput(quantized_main_node.name() + ":0", &requantize_node); + AddNodeInput(quantized_main_node.name() + ":1", &requantize_node); + AddNodeInput(quantized_main_node.name() + ":2", &requantize_node); + AddNodeInput(requantize_min_input, &requantize_node); + AddNodeInput(requantize_max_input, &requantize_node); + new_nodes->push_back(requantize_node); + eight_bit_node_name = requantize_node.name(); + } else { + eight_bit_node_name = quantized_main_node.name(); + } + + // Convert the 8-bit result back into float for the final output. + NodeDef dequantize_node; + dequantize_node.set_op("Dequantize"); + dequantize_node.set_name(float_node.name()); + SetNodeAttr("T", DT_QUINT8, &dequantize_node); + SetNodeAttr("mode", "MIN_FIRST", &dequantize_node); + AddNodeInput(eight_bit_node_name + ":0", &dequantize_node); + AddNodeInput(eight_bit_node_name + ":1", &dequantize_node); + AddNodeInput(eight_bit_node_name + ":2", &dequantize_node); + new_nodes->push_back(dequantize_node); + + return Status::OK(); + }, + {}, &quantized_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def)); + + // If we've ended up with two Requantize ops in a row (for example if there + // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together, + // using the trained range from the second op. + GraphDef merged_graph_def; + TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context, + &merged_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def)); + + // There can be duplicate quantize nodes if multiple ops pull from a single + // input, which makes it harder to remove redundant ones, so strip them out. + GraphDef deduped_graph_def; + TF_RETURN_IF_ERROR( + MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def)); + + // Look for Dequantizes that immediately go into Quantizes, and remove them + // since the two together cancel each other out. This allows us to keep the + // data flow in eight bit where two adjacent ops are in eight bit, but still + // keep interoperability with float ops. + TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context, + output_graph_def)); + TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def)); + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes); + +REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc new file mode 100644 index 00000000000..a82bf781fc6 --- /dev/null +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -0,0 +1,1321 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status QuantizeNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status RemoveRedundantQuantizations(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status QuantizePlaceholders(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status MergeAdjacentRequantizes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status HoistFakeQuants(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +Status MergeDuplicateNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class QuantizeNodesTest : public ::testing::Test { + protected: + void TestTransformedVersusFloatGraph( + const TransformFunc& transform_function, const GraphDef& float_graph_def, + const std::vector>& float_inputs, + const std::vector>& transformed_inputs, + const std::vector& output_names, + const TransformFuncContext& in_context, double threshold, + GraphDef* transformed_graph_def) { + std::unique_ptr float_session(NewSession(SessionOptions())); + TF_ASSERT_OK(float_session->Create(float_graph_def)); + std::vector float_outputs; + TF_ASSERT_OK( + float_session->Run(float_inputs, output_names, {}, &float_outputs)); + + TransformFuncContext context(in_context); + std::vector input_names; + for (const std::pair float_input : + float_inputs) { + context.input_names.push_back(float_input.first); + } + + context.output_names = output_names; + TF_ASSERT_OK( + transform_function(float_graph_def, context, transformed_graph_def)); + + std::unique_ptr transformed_session(NewSession(SessionOptions())); + TF_ASSERT_OK(transformed_session->Create(*transformed_graph_def)); + std::vector transformed_outputs; + TF_ASSERT_OK(transformed_session->Run(transformed_inputs, output_names, {}, + &transformed_outputs)); + + const int output_count = output_names.size(); + EXPECT_EQ(output_count, float_outputs.size()); + EXPECT_EQ(output_count, transformed_outputs.size()); + for (int i = 0; i < output_count; ++i) { + test::ExpectTensorNear(float_outputs[i], transformed_outputs[i], + threshold); + } + } + + void TestQuantizedVersusFloatGraph( + const GraphDef& float_graph_def, + const std::vector>& inputs, + const std::vector& output_names) { + GraphDef quantized_graph_def; + TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def, inputs, + inputs, output_names, {}, 1.0, + &quantized_graph_def); + // Reshape is not included here because it can be added as part of the + // quantization process. + const std::set quantizable_ops = { + "BiasAdd", "Concat", "Conv2D", "MatMul", + "Relu", "Relu6", "AvgPool", "MaxPool", + }; + for (const NodeDef& node : quantized_graph_def.node()) { + EXPECT_EQ(0, quantizable_ops.count(node.op())); + } + } + + void TestGraphWithInputRange( + const GraphDef& float_graph_def, + const std::vector>& float_inputs, + const std::vector& output_names, float range_min, + float range_max) { + TransformFuncContext context; + context.params["input_min"] = {strings::StrCat(range_min)}; + context.params["input_max"] = {strings::StrCat(range_max)}; + + std::vector> quantized_inputs; + for (const std::pair& float_input : float_inputs) { + const Tensor& float_tensor = float_input.second; + Tensor quantized_tensor(DT_QUINT8, float_tensor.shape()); + FloatTensorToQuantizedInPlace(float_tensor, range_min, range_max, + &quantized_tensor); + quantized_inputs.push_back({float_input.first, quantized_tensor}); + } + + GraphDef quantized_graph_def; + TestTransformedVersusFloatGraph( + QuantizeNodes, float_graph_def, float_inputs, quantized_inputs, + output_names, context, 1.0, &quantized_graph_def); + } + + void TestGraphWithFallbackRange( + const GraphDef& float_graph_def, + const std::vector>& float_inputs, + const std::vector& output_names, float range_min, float range_max, + GraphDef* quantized_graph_def) { + TransformFuncContext context; + context.params["fallback_min"] = {strings::StrCat(range_min)}; + context.params["fallback_max"] = {strings::StrCat(range_max)}; + TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def, + float_inputs, float_inputs, output_names, + context, 2.0, quantized_graph_def); + } + + void TestQuantizeMatMul(int m, int n, int k, + const std::vector& a_values, + const std::vector& b_values) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor a_tensor(DT_FLOAT, TensorShape({m, k})); + test::FillValues(&a_tensor, a_values); + Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor)); + + Tensor b_tensor(DT_FLOAT, TensorShape({k, n})); + test::FillValues(&b_tensor, b_values); + Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor)); + + Output mat_mul_op = MatMul(root.WithOpName("mat_mul_op"), a_op, b_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"mat_mul_op"}); + } + + void TestQuantizeMatMulTiny() { + // These tests are added to test the generate case where + // min(matrix) == max(matrix), which used to cause problems. + TestQuantizeMatMul(1, 1, 1, {2}, {3}); + TestQuantizeMatMul(1, 2, 1, {1}, {2, 3}); + TestQuantizeMatMul(1, 1, 2, {1, 1}, {1, 1}); + TestQuantizeMatMul(1, 1, 2, {0, 0}, {1, 1}); + // The general case. + TestQuantizeMatMul(1, 1, 2, {1, 2}, {1, 2}); + } + + void TestQuantizeMatMulSmall() { + TestQuantizeMatMul(2, 4, 3, {1, 2, 3, 4, 5, 6}, + {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + } + + void TestQuantizeConv2D(int depth, int input_width, int input_height, + int input_batch_count, int filter_size, + int filter_count, int stride, const string& padding, + const std::vector& input_values, + const std::vector& filter_values) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_tensor(DT_FLOAT, TensorShape({input_batch_count, input_height, + input_width, depth})); + test::FillValues(&input_tensor, input_values); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_tensor)); + + Tensor filter_tensor( + DT_FLOAT, TensorShape({filter_size, filter_size, depth, filter_count})); + test::FillValues(&filter_tensor, filter_values); + Output filter_op = + Const(root.WithOpName("filter_op"), Input::Initializer(filter_tensor)); + + Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, filter_op, + {1, stride, stride, 1}, padding); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"conv_op"}); + } + + void TestQuantizeBiasAdd() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6})); + test::FillIota(&input_tensor, 1); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_tensor)); + + Tensor offset_tensor(DT_FLOAT, TensorShape({6})); + test::FillIota(&offset_tensor, 1); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Output bias_add_op = + BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"bias_add_op"}); + } + + void TestQuantizeConcat() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor shape_tensor(DT_INT32, TensorShape({})); + test::FillValues(&shape_tensor, {0}); + Output shape_op = + Const(root.WithOpName("shape_op"), Input::Initializer(shape_tensor)); + + Tensor a_tensor(DT_FLOAT, TensorShape({2, 2, 3})); + test::FillValues(&a_tensor, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor)); + + Tensor b_tensor(DT_FLOAT, TensorShape({2, 2, 3})); + test::FillValues(&b_tensor, + {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); + Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor)); + + Output concat_op = + Concat(root.WithOpName("concat_op"), shape_op, {a_op, b_op}); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"concat_op"}); + } + + void TestQuantizeRelu() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1})); + test::FillValues(&constant_tensor, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Output constant_op = Const(root.WithOpName("constant_op"), + Input::Initializer(constant_tensor)); + + Output relu_op = Relu(root.WithOpName("relu_op"), constant_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"relu_op"}); + } + + void TestQuantizeRelu6() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1})); + test::FillValues(&constant_tensor, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Output constant_op = Const(root.WithOpName("constant_op"), + Input::Initializer(constant_tensor)); + + Output relu6_op = Relu6(root.WithOpName("relu6_op"), constant_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"relu6_op"}); + } + + void TestQuantizeMaxPool() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1})); + test::FillValues(&constant_tensor, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Output constant_op = Const(root.WithOpName("constant_op"), + Input::Initializer(constant_tensor)); + + Output max_pool_op = MaxPool(root.WithOpName("max_pool_op"), constant_op, + {1, 2, 2, 1}, {1, 1, 1, 1}, "SAME"); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"max_pool_op"}); + } + + void TestQuantizeAvgPool() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor constant_tensor(DT_FLOAT, TensorShape({1, 2, 6, 1})); + test::FillValues(&constant_tensor, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Output constant_op = Const(root.WithOpName("constant_op"), + Input::Initializer(constant_tensor)); + + Output avg_pool_op = AvgPool(root.WithOpName("avg_pool_op"), constant_op, + {1, 2, 2, 1}, {1, 1, 1, 1}, "SAME"); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"avg_pool_op"}); + } + + void TestQuantizeReshape() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor constant_tensor(DT_FLOAT, TensorShape({4, 5})); + test::FillValues(&constant_tensor, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}); + Output constant_op = Const(root.WithOpName("constant_op"), + Input::Initializer(constant_tensor)); + + Output reshape_op = + Reshape(root.WithOpName("reshape_op"), constant_op, {10, 2}); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TestQuantizedVersusFloatGraph(float_graph_def, {}, {"reshape_op"}); + } + + void TestRemoveRedundantQuantization() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor quantized_tensor(DT_QUINT8, TensorShape({})); + test::FillValues(&quantized_tensor, {0}); + Output quantized_op = Const(root.WithOpName("quantized_op"), + Input::Initializer(quantized_tensor)); + + Tensor quantized_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_min_tensor, {2.0f}); + Output quantized_min_op = Const(root.WithOpName("quantized_min_op"), + Input::Initializer(quantized_min_tensor)); + + Tensor quantized_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_max_tensor, {2.0f}); + Output quantized_max_op = Const(root.WithOpName("quantized_max_op"), + Input::Initializer(quantized_min_tensor)); + + Output dequantize_op = + Dequantize(root.WithOpName("dequantize_op"), quantized_op, + quantized_min_op, quantized_max_op); + + Tensor dequantize_reshape_dims_tensor(DT_INT32, TensorShape({1})); + test::FillValues(&dequantize_reshape_dims_tensor, {-1}); + Output dequantize_reshape_dims = + Const(root.WithOpName("dequantize_reshape_dims"), + Input::Initializer(dequantize_reshape_dims_tensor)); + + Tensor dequantize_reduction_dims_tensor(DT_INT32, TensorShape({})); + test::FillValues(&dequantize_reduction_dims_tensor, {0}); + Output dequantize_reduction_dims = + Const(root.WithOpName("dequantize_reduction_dims"), + Input::Initializer(dequantize_reduction_dims_tensor)); + + Output dequantize_reshape = Reshape(root.WithOpName("dequantize_reshape"), + dequantize_op, dequantize_reshape_dims); + + Output dequantize_min = + Min(root.WithOpName("dequantize_min"), dequantize_reshape, + dequantize_reduction_dims, Min::Attrs().KeepDims(false)); + + Output dequantize_max = + Max(root.WithOpName("dequantize_max"), dequantize_reshape, + dequantize_reduction_dims, Max::Attrs().KeepDims(false)); + + QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op, + dequantize_min, dequantize_max, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Output final_dequantize = + Dequantize(root.WithOpName("final_dequantize"), quantize_op.output, + quantize_op.output_min, quantize_op.output_max); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef removed_graph_def; + TestTransformedVersusFloatGraph( + RemoveRedundantQuantizations, float_graph_def, {}, {}, + {"final_dequantize"}, {}, 1.0, &removed_graph_def); + + std::map node_map; + MapNamesToNodes(removed_graph_def, &node_map); + EXPECT_EQ(1, node_map.count("final_dequantize")); + EXPECT_EQ("quantized_op", node_map.at("final_dequantize")->input(0)); + } + + void TestRemoveRedundantQuantizationWithBiasAdd() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6})); + test::FillValues(&quantized_tensor, {0, 0, 0, 0, 0, 0}); + Output quantized_op = Const(root.WithOpName("quantized_op"), + Input::Initializer(quantized_tensor)); + + Tensor quantized_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_min_tensor, {2.0f}); + Output quantized_min_op = Const(root.WithOpName("quantized_min_op"), + Input::Initializer(quantized_min_tensor)); + + Tensor quantized_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_max_tensor, {2.0f}); + Output quantized_max_op = Const(root.WithOpName("quantized_max_op"), + Input::Initializer(quantized_min_tensor)); + + Tensor offset_tensor(DT_QUINT8, TensorShape({6})); + test::FillValues(&offset_tensor, {1, 2, 3, 4, 5, 6}); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Tensor offset_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&offset_min_tensor, {0.0f}); + Output offset_min_op = Const(root.WithOpName("offset_min_op"), + Input::Initializer(offset_min_tensor)); + + Tensor offset_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&offset_max_tensor, {255.0f}); + Output offset_max_op = Const(root.WithOpName("offset_max_op"), + Input::Initializer(offset_max_tensor)); + + QuantizedBiasAdd quantized_bias_add_op( + root.WithOpName("bias_add_op"), quantized_op, offset_op, + quantized_min_op, quantized_max_op, offset_min_op, offset_max_op, + DT_QINT32); + + RequantizationRange requantization_range_op( + root.WithOpName("requantization_range_op"), + quantized_bias_add_op.output, quantized_bias_add_op.min_out, + quantized_bias_add_op.max_out); + + Requantize requantize_op( + root.WithOpName("requantize_op"), quantized_bias_add_op.output, + quantized_bias_add_op.min_out, quantized_bias_add_op.max_out, + requantization_range_op.output_min, requantization_range_op.output_max, + DT_QUINT8); + + Output dequantize_op = + Dequantize(root.WithOpName("dequantize_op"), requantize_op.output, + requantize_op.output_min, requantize_op.output_max); + + Tensor dequantize_reshape_dims_tensor(DT_INT32, TensorShape({1})); + test::FillValues(&dequantize_reshape_dims_tensor, {-1}); + Output dequantize_reshape_dims = + Const(root.WithOpName("dequantize_reshape_dims"), + Input::Initializer(dequantize_reshape_dims_tensor)); + + Tensor dequantize_reduction_dims_tensor(DT_INT32, TensorShape({})); + test::FillValues(&dequantize_reduction_dims_tensor, {0}); + Output dequantize_reduction_dims = + Const(root.WithOpName("dequantize_reduction_dims"), + Input::Initializer(dequantize_reduction_dims_tensor)); + + Output dequantize_reshape = Reshape(root.WithOpName("dequantize_reshape"), + dequantize_op, dequantize_reshape_dims); + + Output dequantize_min = + Min(root.WithOpName("dequantize_min"), dequantize_reshape, + dequantize_reduction_dims, Min::Attrs().KeepDims(false)); + + Output dequantize_max = + Max(root.WithOpName("dequantize_max"), dequantize_reshape, + dequantize_reduction_dims, Max::Attrs().KeepDims(false)); + + QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op, + dequantize_min, dequantize_max, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Output final_dequantize = + Dequantize(root.WithOpName("final_dequantize"), quantize_op.output, + quantize_op.output_min, quantize_op.output_max); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef removed_graph_def; + TestTransformedVersusFloatGraph( + RemoveRedundantQuantizations, float_graph_def, {}, {}, + {"final_dequantize"}, {}, 1.0, &removed_graph_def); + + std::map node_map; + MapNamesToNodes(removed_graph_def, &node_map); + EXPECT_EQ(1, node_map.count("final_dequantize")); + EXPECT_EQ("requantize_op", node_map.at("final_dequantize")->input(0)); + } + + void TestRemoveRedundantQuantizationWithMultipleOutputs() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor quantized_tensor(DT_QUINT8, TensorShape({1, 6})); + test::FillValues(&quantized_tensor, {0, 0, 0, 0, 0, 0}); + Output quantized_op = Const(root.WithOpName("quantized_op"), + Input::Initializer(quantized_tensor)); + + Tensor quantized_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_min_tensor, {2.0f}); + Output quantized_min_op = Const(root.WithOpName("quantized_min_op"), + Input::Initializer(quantized_min_tensor)); + + Tensor quantized_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_max_tensor, {2.0f}); + Output quantized_max_op = Const(root.WithOpName("quantized_max_op"), + Input::Initializer(quantized_min_tensor)); + + Tensor offset_tensor(DT_QUINT8, TensorShape({6})); + test::FillValues(&offset_tensor, {1, 2, 3, 4, 5, 6}); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Tensor offset_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&offset_min_tensor, {0.0f}); + Output offset_min_op = Const(root.WithOpName("offset_min_op"), + Input::Initializer(offset_min_tensor)); + + Tensor offset_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&offset_max_tensor, {255.0f}); + Output offset_max_op = Const(root.WithOpName("offset_max_op"), + Input::Initializer(offset_max_tensor)); + + QuantizedBiasAdd quantized_bias_add_op( + root.WithOpName("bias_add_op"), quantized_op, offset_op, + quantized_min_op, quantized_max_op, offset_min_op, offset_max_op, + DT_QINT32); + + RequantizationRange requantization_range_op( + root.WithOpName("requantization_range_op"), + quantized_bias_add_op.output, quantized_bias_add_op.min_out, + quantized_bias_add_op.max_out); + + Requantize requantize_op( + root.WithOpName("requantize_op"), quantized_bias_add_op.output, + quantized_bias_add_op.min_out, quantized_bias_add_op.max_out, + requantization_range_op.output_min, requantization_range_op.output_max, + DT_QUINT8); + + Output dequantize_op = + Dequantize(root.WithOpName("dequantize_op"), requantize_op.output, + requantize_op.output_min, requantize_op.output_max); + + Tensor dequantize_reshape_dims_tensor(DT_INT32, TensorShape({1})); + test::FillValues(&dequantize_reshape_dims_tensor, {-1}); + Output dequantize_reshape_dims = + Const(root.WithOpName("dequantize_reshape_dims"), + Input::Initializer(dequantize_reshape_dims_tensor)); + + Tensor dequantize_reduction_dims_tensor(DT_INT32, TensorShape({})); + test::FillValues(&dequantize_reduction_dims_tensor, {0}); + Output dequantize_reduction_dims = + Const(root.WithOpName("dequantize_reduction_dims"), + Input::Initializer(dequantize_reduction_dims_tensor)); + + Output dequantize_reshape = Reshape(root.WithOpName("dequantize_reshape"), + dequantize_op, dequantize_reshape_dims); + + Output dequantize_min = + Min(root.WithOpName("dequantize_min"), dequantize_reshape, + dequantize_reduction_dims, Min::Attrs().KeepDims(false)); + + Output dequantize_max = + Max(root.WithOpName("dequantize_max"), dequantize_reshape, + dequantize_reduction_dims, Max::Attrs().KeepDims(false)); + + QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op, + dequantize_min, dequantize_max, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Output final_dequantize = + Dequantize(root.WithOpName("final_dequantize"), quantize_op.output, + quantize_op.output_min, quantize_op.output_max); + + Output relu_op = Relu(root.WithOpName("relu_op"), dequantize_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef removed_graph_def; + TestTransformedVersusFloatGraph( + RemoveRedundantQuantizations, float_graph_def, {}, {}, + {"final_dequantize", "relu_op"}, {}, 1.0, &removed_graph_def); + + std::map op_type_count; + for (const NodeDef& node : removed_graph_def.node()) { + ++op_type_count[node.op()]; + } + EXPECT_EQ(2, op_type_count["Dequantize"]); + } + + void TestQuantizePlaceholders() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Output placeholder_op = + Placeholder(root.WithOpName("placeholder_op"), DT_FLOAT); + + Output relu_op = Relu(root.WithOpName("relu_op"), placeholder_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + TransformFuncContext context; + context.input_names = {"placeholder_op"}; + context.output_names = {"relu_op"}; + context.params = {{"input_min", {"-10.0"}}, {"input_max", {"10.0"}}}; + + GraphDef quantized_graph_def; + TF_ASSERT_OK( + QuantizePlaceholders(float_graph_def, context, &quantized_graph_def)); + + Tensor input_tensor(DT_FLOAT, {}); + input_tensor.flat()(0) = 5.0f; + + TestQuantizedVersusFloatGraph( + float_graph_def, {{"placeholder_op", input_tensor}}, {"relu_op"}); + + std::map node_map; + MapNamesToNodes(quantized_graph_def, &node_map); + EXPECT_NE("placeholder_op", node_map.at("relu_op")->input(0)); + EXPECT_EQ("Placeholder", node_map.at("placeholder_op")->op()); + EXPECT_EQ(DT_QUINT8, + node_map.at("placeholder_op")->attr().at("dtype").type()); + } + + void TestInputRange() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({1, width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output bias_add = + BiasAdd(root.WithOpName("bias_add"), a_const, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&placeholder_tensor, 1.0f); + + TestGraphWithInputRange(graph_def, {{"placeholder", placeholder_tensor}}, + {"bias_add"}, 0.0f, 100.0f); + } + + void TestFallbackRange() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({1, width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output bias_add = + BiasAdd(root.WithOpName("bias_add"), a_const, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&placeholder_tensor, 1.0f); + + GraphDef quantized_graph_def; + TestGraphWithFallbackRange(graph_def, {{"placeholder", placeholder_tensor}}, + {"bias_add"}, 0.0f, 200.0f, + &quantized_graph_def); + + for (const NodeDef& node : quantized_graph_def.node()) { + EXPECT_NE("RequantizationRange", node.op()); + } + } + + void TestConvertFakeQuantsToRequantize() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6})); + test::FillIota(&input_tensor, 1); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_tensor)); + + Tensor offset_tensor(DT_FLOAT, TensorShape({6})); + test::FillIota(&offset_tensor, 1); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Output bias_add_op = + BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op); + + Tensor fake_quant_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_quant_min_tensor, {0.0f}); + Output fake_quant_min_op = Const(root.WithOpName("fake_quant_min_op"), + Input::Initializer(fake_quant_min_tensor)); + + Tensor fake_quant_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_quant_max_tensor, {18.0f}); + Output fake_quant_max_op = Const(root.WithOpName("fake_quant_max_op"), + Input::Initializer(fake_quant_max_tensor)); + + Output fake_quant_op = + FakeQuantWithMinMaxVars(root.WithOpName("fake_quant_op"), bias_add_op, + fake_quant_min_op, fake_quant_max_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef converted_graph_def; + TestTransformedVersusFloatGraph(ConvertFakeQuantsToRequantize, + float_graph_def, {}, {}, {"fake_quant_op"}, + {}, 1.0, &converted_graph_def); + + for (const NodeDef& node : converted_graph_def.node()) { + EXPECT_NE("FakeQuantWithMinMaxVars", node.op()); + } + } + + void TestMergeAdjacentRequantizes() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_tensor(DT_QUINT8, TensorShape({1, 1, 2, 6})); + test::FillValues(&input_tensor, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_tensor)); + + Tensor input_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&input_min_tensor, {0.0f}); + Output input_min_op = Const(root.WithOpName("input_min_op"), + Input::Initializer(input_min_tensor)); + + Tensor input_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&input_max_tensor, {255.0f}); + Output input_max_op = Const(root.WithOpName("input_max_op"), + Input::Initializer(input_max_tensor)); + + Tensor offset_tensor(DT_QUINT8, TensorShape({6})); + test::FillValues(&offset_tensor, {1, 2, 3, 4, 5, 6}); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Tensor offset_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&offset_min_tensor, {0.0f}); + Output offset_min_op = Const(root.WithOpName("offset_min_op"), + Input::Initializer(offset_min_tensor)); + + Tensor offset_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&offset_max_tensor, {255.0f}); + Output offset_max_op = Const(root.WithOpName("offset_max_op"), + Input::Initializer(offset_max_tensor)); + + QuantizedBiasAdd quantized_bias_add_op( + root.WithOpName("quantized_bias_add_op"), input_op, offset_op, + input_min_op, input_max_op, offset_min_op, offset_max_op, DT_QINT32); + + RequantizationRange requantization_range_op( + root.WithOpName("requantization_range_op"), + quantized_bias_add_op.output, quantized_bias_add_op.min_out, + quantized_bias_add_op.max_out); + + Requantize requantize_op( + root.WithOpName("requantize_op"), quantized_bias_add_op.output, + quantized_bias_add_op.min_out, quantized_bias_add_op.max_out, + requantization_range_op.output_min, requantization_range_op.output_max, + DT_QUINT8); + + Output dequantize_op = + Dequantize(root.WithOpName("dequantize_op"), requantize_op.output, + requantize_op.output_min, requantize_op.output_max, + Dequantize::Attrs().Mode("MIN_FIRST")); + + Tensor quantize_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantize_min_tensor, {0.0f}); + Output quantize_min_op = Const(root.WithOpName("quantize_min_op"), + Input::Initializer(quantize_min_tensor)); + + Tensor quantize_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantize_max_tensor, {255.0f}); + Output quantize_max_op = Const(root.WithOpName("quantize_max_op"), + Input::Initializer(quantize_max_tensor)); + + QuantizeV2 quantize_op(root.WithOpName("quantize_op"), dequantize_op, + quantize_min_op, quantize_max_op, DT_QINT32, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Tensor fake_requantize_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_requantize_min_tensor, {0.0f}); + Output fake_requantize_min_op = + Const(root.WithOpName("fake_requantize_min_op"), + Input::Initializer(fake_requantize_min_tensor)); + + Tensor fake_requantize_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_requantize_max_tensor, {255.0f}); + Output fake_requantize_max_op = + Const(root.WithOpName("fake_requantize_max_op"), + Input::Initializer(fake_requantize_max_tensor)); + + Requantize fake_requantize_op( + root.WithOpName("fake_requantize_op"), quantize_op.output, + quantize_op.output_min, quantize_op.output_max, fake_requantize_min_op, + fake_requantize_max_op, DT_QUINT8); + + Output fake_dequantize_op = Dequantize( + root.WithOpName("fake_dequantize_op"), fake_requantize_op.output, + fake_requantize_op.output_min, fake_requantize_op.output_max, + Dequantize::Attrs().Mode("MIN_FIRST")); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef converted_graph_def; + TestTransformedVersusFloatGraph(MergeAdjacentRequantizes, float_graph_def, + {}, {}, {"fake_dequantize_op"}, {}, 1.0, + &converted_graph_def); + + int requantize_count = 0; + for (const NodeDef& node : converted_graph_def.node()) { + if (node.op() == "Requantize") { + ++requantize_count; + } + } + EXPECT_EQ(1, requantize_count); + } + + void TestConvertFakeQuantsEndToEnd() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6})); + test::FillIota(&input_tensor, 1); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_tensor)); + + Tensor offset_tensor(DT_FLOAT, TensorShape({6})); + test::FillIota(&offset_tensor, 1); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Output bias_add_op = + BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op); + + Tensor fake_quant_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_quant_min_tensor, {0.0f}); + Output fake_quant_min_op = Const(root.WithOpName("fake_quant_min_op"), + Input::Initializer(fake_quant_min_tensor)); + + Tensor fake_quant_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_quant_max_tensor, {18.0f}); + Output fake_quant_max_op = Const(root.WithOpName("fake_quant_max_op"), + Input::Initializer(fake_quant_max_tensor)); + + Output fake_quant_op = + FakeQuantWithMinMaxVars(root.WithOpName("fake_quant_op"), bias_add_op, + fake_quant_min_op, fake_quant_max_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef converted_graph_def; + TestTransformedVersusFloatGraph(QuantizeNodes, float_graph_def, {}, {}, + {"fake_quant_op"}, {}, 1.0, + &converted_graph_def); + + int requantize_count = 0; + for (const NodeDef& node : converted_graph_def.node()) { + EXPECT_NE("FakeQuantWithMinMaxVars", node.op()); + if (node.op() == "Requantize") { + ++requantize_count; + } + } + EXPECT_EQ(1, requantize_count); + } + + void TestHoistFakeQuants() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 2, 6})); + test::FillIota(&input_tensor, 1); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_tensor)); + + Tensor offset_tensor(DT_FLOAT, TensorShape({6})); + test::FillIota(&offset_tensor, 1); + Output offset_op = + Const(root.WithOpName("offset_op"), Input::Initializer(offset_tensor)); + + Output bias_add_op = + BiasAdd(root.WithOpName("bias_add_op"), input_op, offset_op); + + Output relu_op = Relu(root.WithOpName("relu_op"), bias_add_op); + + Output max_pool_op = MaxPool(root.WithOpName("max_pool_op"), relu_op, + {1, 2, 2, 1}, {1, 1, 1, 1}, "SAME"); + + Tensor fake_quant_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_quant_min_tensor, {0.0f}); + Output fake_quant_min_op = Const(root.WithOpName("fake_quant_min_op"), + Input::Initializer(fake_quant_min_tensor)); + + Tensor fake_quant_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&fake_quant_max_tensor, {18.0f}); + Output fake_quant_max_op = Const(root.WithOpName("fake_quant_max_op"), + Input::Initializer(fake_quant_max_tensor)); + + Output fake_quant_op = + FakeQuantWithMinMaxVars(root.WithOpName("fake_quant_op"), max_pool_op, + fake_quant_min_op, fake_quant_max_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef converted_graph_def; + TestTransformedVersusFloatGraph(HoistFakeQuants, float_graph_def, {}, {}, + {"fake_quant_op"}, {}, 1.0, + &converted_graph_def); + + std::map node_map; + MapNamesToNodes(converted_graph_def, &node_map); + EXPECT_EQ("MaxPool", node_map.at("fake_quant_op")->op()); + EXPECT_EQ("FakeQuantWithMinMaxVars", + node_map.at(node_map.at("relu_op")->input(0))->op()); + } + + void TestMergeDuplicateQuantizes() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor quantized_tensor(DT_QUINT8, TensorShape({})); + test::FillValues(&quantized_tensor, {0}); + Output quantized_op = Const(root.WithOpName("quantized_op"), + Input::Initializer(quantized_tensor)); + + Tensor quantized_min_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_min_tensor, {2.0f}); + Output quantized_min_op = Const(root.WithOpName("quantized_min_op"), + Input::Initializer(quantized_min_tensor)); + + Tensor quantized_max_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&quantized_max_tensor, {2.0f}); + Output quantized_max_op = Const(root.WithOpName("quantized_max_op"), + Input::Initializer(quantized_min_tensor)); + + Output dequantize_op = + Dequantize(root.WithOpName("dequantize_op"), quantized_op, + quantized_min_op, quantized_max_op); + + Tensor quantize_reshape_dims1_tensor(DT_INT32, TensorShape({1})); + test::FillValues(&quantize_reshape_dims1_tensor, {-1}); + Output quantize_reshape_dims1 = + Const(root.WithOpName("dequantize_reshape_dims1"), + Input::Initializer(quantize_reshape_dims1_tensor)); + + Tensor quantize_reduction_dims1_tensor(DT_INT32, TensorShape({})); + test::FillValues(&quantize_reduction_dims1_tensor, {0}); + Output quantize_reduction_dims1 = + Const(root.WithOpName("quantize_reduction_dims1"), + Input::Initializer(quantize_reduction_dims1_tensor)); + + Output quantize_reshape1 = Reshape(root.WithOpName("quantize_reshape1"), + dequantize_op, quantize_reshape_dims1); + + Output quantize_min1 = + Min(root.WithOpName("quantize_min1"), quantize_reshape1, + quantize_reduction_dims1, Min::Attrs().KeepDims(false)); + + Output quantize_max1 = + Max(root.WithOpName("quantize_max1"), quantize_reshape1, + quantize_reduction_dims1, Max::Attrs().KeepDims(false)); + + QuantizeV2 quantize_op1(root.WithOpName("quantize_op1"), dequantize_op, + quantize_min1, quantize_max1, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Tensor quantize_reshape_dims2_tensor(DT_INT32, TensorShape({1})); + test::FillValues(&quantize_reshape_dims2_tensor, {-1}); + Output quantize_reshape_dims2 = + Const(root.WithOpName("dequantize_reshape_dims2"), + Input::Initializer(quantize_reshape_dims2_tensor)); + + Tensor quantize_reduction_dims2_tensor(DT_INT32, TensorShape({})); + test::FillValues(&quantize_reduction_dims2_tensor, {0}); + Output quantize_reduction_dims2 = + Const(root.WithOpName("quantize_reduction_dims2"), + Input::Initializer(quantize_reduction_dims2_tensor)); + + Output quantize_reshape2 = Reshape(root.WithOpName("quantize_reshape2"), + dequantize_op, quantize_reshape_dims2); + + Output quantize_min2 = + Min(root.WithOpName("quantize_min2"), quantize_reshape2, + quantize_reduction_dims2, Min::Attrs().KeepDims(false)); + + Output quantize_max2 = + Max(root.WithOpName("quantize_max2"), quantize_reshape2, + quantize_reduction_dims2, Max::Attrs().KeepDims(false)); + + QuantizeV2 quantize_op2(root.WithOpName("quantize_op2"), dequantize_op, + quantize_min1, quantize_max1, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Output final_dequantize1 = + Dequantize(root.WithOpName("final_dequantize1"), quantize_op1.output, + quantize_op1.output_min, quantize_op1.output_max); + + Output final_dequantize2 = + Dequantize(root.WithOpName("final_dequantize2"), quantize_op2.output, + quantize_op2.output_min, quantize_op2.output_max); + + Output add_op = + Add(root.WithOpName("add_op"), final_dequantize1, final_dequantize2); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef merged_graph_def; + TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, {}, + {}, {"add_op"}, {}, 1.0, &merged_graph_def); + + std::map op_map; + for (const NodeDef& node : merged_graph_def.node()) { + ++op_map[node.op()]; + } + EXPECT_EQ(1, op_map["QuantizeV2"]); + } + + void TestMergeDuplicateConsts() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_tensor, 1.0f); + Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor)); + + Tensor b_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_tensor, 1.0f); + Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor)); + + Output add_op = Add(root.WithOpName("add_op"), a_op, b_op); + + Tensor c_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&c_tensor, 2.0f); + Output c_op = Const(root.WithOpName("c_op"), Input::Initializer(c_tensor)); + + Output mul_op = Mul(root.WithOpName("mul_op"), add_op, c_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef merged_graph_def; + TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, {}, + {}, {"mul_op"}, {}, 1.0, &merged_graph_def); + + std::map node_map; + MapNamesToNodes(merged_graph_def, &node_map); + EXPECT_EQ(1, (node_map.count("a_op") + node_map.count("b_op"))); + string remaining_const; + if (node_map.count("a_op")) { + remaining_const = "a_op"; + } else { + remaining_const = "b_op"; + } + EXPECT_EQ(remaining_const, node_map["add_op"]->input(0)); + EXPECT_EQ(remaining_const, node_map["add_op"]->input(1)); + EXPECT_EQ(1, node_map.count("c_op")); + EXPECT_EQ("add_op", node_map["mul_op"]->input(0)); + EXPECT_EQ("c_op", node_map["mul_op"]->input(1)); + } + + void TestMergeDuplicatesNested() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_tensor, 1.0f); + Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor)); + + Output a_relu_op = Relu(root.WithOpName("a_relu_op"), a_op); + + Tensor b_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_tensor, 1.0f); + Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor)); + + Output b_relu_op = Relu(root.WithOpName("b_relu_op"), b_op); + + Output add_op = Add(root.WithOpName("add_op"), a_relu_op, b_relu_op); + + Tensor c_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&c_tensor, 2.0f); + Output c_op = Const(root.WithOpName("c_op"), Input::Initializer(c_tensor)); + + Output mul_op = Mul(root.WithOpName("mul_op"), add_op, c_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef merged_graph_def; + TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, {}, + {}, {"mul_op"}, {}, 1.0, &merged_graph_def); + + std::map node_map; + MapNamesToNodes(merged_graph_def, &node_map); + EXPECT_EQ(1, (node_map.count("a_op") + node_map.count("b_op"))); + EXPECT_EQ(1, (node_map.count("a_relu_op") + node_map.count("b_relu_op"))); + string remaining_relu; + if (node_map.count("a_relu_op")) { + remaining_relu = "a_relu_op"; + } else { + remaining_relu = "b_relu_op"; + } + EXPECT_EQ(remaining_relu, node_map["add_op"]->input(0)); + EXPECT_EQ(remaining_relu, node_map["add_op"]->input(1)); + EXPECT_EQ(1, node_map.count("c_op")); + EXPECT_EQ("add_op", node_map["mul_op"]->input(0)); + EXPECT_EQ("c_op", node_map["mul_op"]->input(1)); + } + + void TestMergeDuplicatesInOut() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_tensor, 1.0f); + Output a_op = Const(root.WithOpName("a_op"), Input::Initializer(a_tensor)); + + Output a_relu_op = Relu(root.WithOpName("a_relu_op"), a_op); + + Tensor b_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_tensor, 1.0f); + Output b_op = Const(root.WithOpName("b_op"), Input::Initializer(b_tensor)); + + Output b_relu_op = Relu(root.WithOpName("b_relu_op"), b_op); + + Output add_op = Add(root.WithOpName("add_op"), a_relu_op, b_relu_op); + + Tensor c_tensor(DT_FLOAT, TensorShape({width})); + test::FillIota(&c_tensor, 2.0f); + Output c_op = Const(root.WithOpName("c_op"), Input::Initializer(c_tensor)); + + Output mul_op1 = Mul(root.WithOpName("mul_op1"), add_op, c_op); + Output mul_op2 = Mul(root.WithOpName("mul_op2"), add_op, c_op); + Output mul_op3 = Mul(root.WithOpName("mul_op3"), add_op, c_op); + + Output final_mul_op = + Mul(root.WithOpName("final_mul_op"), mul_op2, mul_op3); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + GraphDef merged_graph_def; + TestTransformedVersusFloatGraph(MergeDuplicateNodes, float_graph_def, + {{"a_op", a_tensor}}, {{"a_op", a_tensor}}, + {"mul_op1", "final_mul_op"}, {}, 1.0, + &merged_graph_def); + + std::map node_map; + MapNamesToNodes(merged_graph_def, &node_map); + EXPECT_EQ(1, node_map.count("a_op")); + EXPECT_EQ(1, node_map.count("b_op")); + EXPECT_EQ(1, node_map.count("a_relu_op")); + EXPECT_EQ(1, node_map.count("b_relu_op")); + EXPECT_EQ(1, node_map.count("mul_op1")); + EXPECT_EQ(1, node_map.count("final_mul_op")); + EXPECT_EQ(1, (node_map.count("mul_op2") + node_map.count("mul_op3"))); + string remaining_mul; + if (node_map.count("mul_op2")) { + remaining_mul = "mul_op2"; + } else { + remaining_mul = "mul_op3"; + } + EXPECT_EQ(remaining_mul, node_map["final_mul_op"]->input(0)); + EXPECT_EQ(remaining_mul, node_map["final_mul_op"]->input(1)); + EXPECT_EQ(1, node_map.count("c_op")); + EXPECT_EQ("add_op", node_map["mul_op1"]->input(0)); + EXPECT_EQ("c_op", node_map["mul_op1"]->input(1)); + } +}; + +TEST_F(QuantizeNodesTest, TestQuantizeMatMulTiny) { TestQuantizeMatMulTiny(); } + +TEST_F(QuantizeNodesTest, TestQuantizeMatMulSmall) { + TestQuantizeMatMulSmall(); +} + +TEST_F(QuantizeNodesTest, TestOddPaddingProblem) { + // Tests one error case we ran into in a real graph. + TestQuantizeConv2D(1, 4, 4, 1, 3, 1, 2, "SAME", + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); +} + +TEST_F(QuantizeNodesTest, TestQuantizeConv2D) { + TestQuantizeConv2D(1, 4, 3, 1, 3, 1, 1, "SAME", + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + {1, 4, 7, 2, 5, 8, 3, 6, 9}); +} + +TEST_F(QuantizeNodesTest, TestQuantizeBiasAdd) { TestQuantizeBiasAdd(); } + +TEST_F(QuantizeNodesTest, TestQuantizeConcat) { TestQuantizeConcat(); } + +TEST_F(QuantizeNodesTest, TestQuantizeRelu) { TestQuantizeRelu(); } + +TEST_F(QuantizeNodesTest, TestQuantizeRelu6) { TestQuantizeRelu6(); } + +TEST_F(QuantizeNodesTest, TestQuantizeMaxPool) { TestQuantizeMaxPool(); } + +TEST_F(QuantizeNodesTest, TestQuantizeAvgPool) { TestQuantizeAvgPool(); } + +TEST_F(QuantizeNodesTest, TestQuantizeReshape) { TestQuantizeReshape(); } + +TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantization) { + TestRemoveRedundantQuantization(); +} + +TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantizationWithBiasAdd) { + TestRemoveRedundantQuantizationWithBiasAdd(); +} + +TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantizationWithMultipleOutputs) { + TestRemoveRedundantQuantizationWithMultipleOutputs(); +} + +TEST_F(QuantizeNodesTest, TestQuantizePlaceholders) { + TestQuantizePlaceholders(); +} + +TEST_F(QuantizeNodesTest, TestInputRange) { TestInputRange(); } + +TEST_F(QuantizeNodesTest, TestFallbackRange) { TestFallbackRange(); } + +TEST_F(QuantizeNodesTest, TestConvertFakeQuantsToRequantize) { + TestConvertFakeQuantsToRequantize(); +} + +TEST_F(QuantizeNodesTest, TestMergeAdjacentRequantizes) { + TestMergeAdjacentRequantizes(); +} + +TEST_F(QuantizeNodesTest, TestConvertFakeQuantsEndToEnd) { + TestConvertFakeQuantsEndToEnd(); +} + +TEST_F(QuantizeNodesTest, TestHoistFakeQuants) { TestHoistFakeQuants(); } + +TEST_F(QuantizeNodesTest, TestMergeDuplicateQuantizes) { + TestMergeDuplicateQuantizes(); +} + +TEST_F(QuantizeNodesTest, TestMergeDuplicateConsts) { + TestMergeDuplicateConsts(); +} + +TEST_F(QuantizeNodesTest, TestMergeDuplicatesNested) { + TestMergeDuplicatesNested(); +} + +TEST_F(QuantizeNodesTest, TestMergeDuplicateInOut) { + TestMergeDuplicatesInOut(); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/quantize_weights.cc b/tensorflow/tools/graph_transforms/quantize_weights.cc new file mode 100644 index 00000000000..e6f1498224f --- /dev/null +++ b/tensorflow/tools/graph_transforms/quantize_weights.cc @@ -0,0 +1,139 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Converts any large float constants into eight-bit equivalents, with a +// Dequantize op so that subsequent nodes can still access the results in a +// float form. +Status QuantizeWeights(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, {"Const"}, + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& old_const_node = match.node; + if (!old_const_node.attr().count("dtype")) { + return errors::InvalidArgument("No 'dtype' attribute for Const node ", + old_const_node.name()); + } + if (!old_const_node.attr().count("value")) { + return errors::InvalidArgument("No 'value' attribute for Const node ", + old_const_node.name()); + } + const DataType old_dtype = old_const_node.attr().at("dtype").type(); + Tensor old_tensor; + if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) { + return errors::InvalidArgument("Decoding Tensor failed for node", + old_const_node.name()); + } + const size_t num_elements = old_tensor.NumElements(); + // If this isn't a float constant, or it's too small, then reuse the + // same node with no changes. + if ((old_dtype != DT_FLOAT) || (num_elements < 16)) { + new_nodes->push_back(old_const_node); + return Status::OK(); + } + const float* old_values = old_tensor.flat().data(); + float min = std::numeric_limits::max(); + float max = std::numeric_limits::min(); + for (int i = 0; i < num_elements; ++i) { + const float value = old_values[i]; + min = std::min(min, value); + max = std::max(max, value); + } + // min_value == max_value is a tricky case. It can occur for general + // tensors, and of course for scalars. The quantized ops cannot deal + // with this case, so we set max_value to something else. + // It's a tricky question what is the numerically best solution to + // deal with this degeneracy. + // TODO(petewarden): Better use a tolerance than a hard comparison? + if (min == max) { + if (std::abs(min) < 0.000001f) { + max = min + 1.0f; + } else if (min > 0) { + max = 2.0f * min; + } else { + max = min / 2.0f; + } + } + Tensor quantized_tensor(DT_QUINT8, old_tensor.shape()); + FloatTensorToQuantizedInPlace(old_tensor, min, max, + &quantized_tensor); + + NodeDef quantized_const_node; + quantized_const_node.set_op("Const"); + quantized_const_node.set_name(old_const_node.name() + + "_quantized_const"); + SetNodeAttr("dtype", DT_QUINT8, &quantized_const_node); + SetNodeTensorAttr("value", quantized_tensor, + &quantized_const_node); + new_nodes->push_back(quantized_const_node); + + NodeDef min_node; + min_node.set_op("Const"); + min_node.set_name(old_const_node.name() + "_quantized_min"); + SetNodeAttr("dtype", DT_FLOAT, &min_node); + Tensor min_tensor(DT_FLOAT, {}); + min_tensor.scalar()() = min; + SetNodeTensorAttr("value", min_tensor, &min_node); + new_nodes->push_back(min_node); + + NodeDef max_node; + max_node.set_op("Const"); + max_node.set_name(old_const_node.name() + "_quantized_max"); + SetNodeAttr("dtype", DT_FLOAT, &max_node); + Tensor max_tensor(DT_FLOAT, {}); + max_tensor.scalar()() = max; + SetNodeTensorAttr("value", max_tensor, &max_node); + new_nodes->push_back(max_node); + + NodeDef dequantize_node; + dequantize_node.set_op("Dequantize"); + dequantize_node.set_name(old_const_node.name()); + SetNodeAttr("T", DT_QUINT8, &dequantize_node); + SetNodeAttr("mode", "MIN_FIRST", &dequantize_node); + AddNodeInput(quantized_const_node.name(), &dequantize_node); + AddNodeInput(min_node.name(), &dequantize_node); + AddNodeInput(max_node.name(), &dequantize_node); + new_nodes->push_back(dequantize_node); + + return Status::OK(); + }, + {}, output_graph_def)); + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("quantize_weights", QuantizeWeights); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc new file mode 100644 index 00000000000..cd5feed3580 --- /dev/null +++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status QuantizeWeights(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class QuantizeWeightsTest : public ::testing::Test { + protected: + void TestQuantizeWeights() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10})); + test::FillValues( + &weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, + 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, + 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, + 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef quantized_graph_def; + TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}}, + &quantized_graph_def)); + + std::unique_ptr quantized_session(NewSession(SessionOptions())); + TF_ASSERT_OK(quantized_session->Create(quantized_graph_def)); + std::vector quantized_outputs; + TF_ASSERT_OK( + quantized_session->Run({}, {"output"}, {}, &quantized_outputs)); + + test::ExpectTensorNear(original_outputs[0], quantized_outputs[0], + 0.5); + + std::map node_lookup; + MapNamesToNodes(quantized_graph_def, &node_lookup); + EXPECT_EQ(1, node_lookup.count("input_op")); + const NodeDef* q_input_op = node_lookup.at("input_op"); + EXPECT_EQ(DT_FLOAT, q_input_op->attr().at("dtype").type()); + EXPECT_EQ(1, node_lookup.count("weights_op")); + const NodeDef* q_weights_op = node_lookup.at("weights_op"); + EXPECT_EQ("Dequantize", q_weights_op->op()); + const string& weights_const_name = + NodeNameFromInput(q_weights_op->input(0)); + EXPECT_EQ(1, node_lookup.count(weights_const_name)); + const NodeDef* q_weights_const = node_lookup.at(weights_const_name); + EXPECT_EQ("Const", q_weights_const->op()); + EXPECT_EQ(DT_QUINT8, q_weights_const->attr().at("dtype").type()); + } +}; + +TEST_F(QuantizeWeightsTest, TestQuantizeWeights) { TestQuantizeWeights(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_attribute.cc b/tensorflow/tools/graph_transforms/remove_attribute.cc new file mode 100644 index 00000000000..dd7ec8a0c63 --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_attribute.cc @@ -0,0 +1,70 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Deletes a given attribute from the specified nodes. +Status RemoveAttribute(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + if (!context.params.count("attribute_name") || + (context.params.at("attribute_name").size() != 1)) { + return errors::InvalidArgument( + "remove_nodes expects exactly one 'attribute_name' " + "argument, e.g. remove_attribute(op_name=Mul, attribute_name=foo)"); + } + + string op_name; + if (context.params.count("op_name")) { + if (context.params.at("op_name").size() != 1) { + return errors::InvalidArgument( + "remove_nodes expects a single op_name argument, but found ", + context.params.at("op_name").size()); + } + op_name = context.params.at("op_name")[0]; + } else { + op_name = "*"; + } + + const string attribute_name = context.params.at("attribute_name")[0]; + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + if (((op_name == "*") || (op_name == node.op())) && + (node.attr().count(attribute_name))) { + new_node->mutable_attr()->erase(attribute_name); + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("remove_attribute", RemoveAttribute); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_attribute_test.cc b/tensorflow/tools/graph_transforms/remove_attribute_test.cc new file mode 100644 index 00000000000..77a69864b0f --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_attribute_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RemoveAttribute(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RemoveAttributeTest : public ::testing::Test { + protected: + void TestRemoveAttribute() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + SetNodeAttr("foo", 23, mul_node1); + SetNodeAttr("bar", "something", mul_node1); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + SetNodeAttr("foo", 46, add_node2); + SetNodeAttr("bob", 23, add_node2); + SetNodeAttr("bar", "something else", add_node2); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef wildcard_result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + context.params.insert( + std::pair>({"op_name", {string("*")}})); + context.params.insert(std::pair>( + {"attribute_name", {string("foo")}})); + TF_ASSERT_OK(RemoveAttribute(graph_def, context, &wildcard_result)); + + std::map node_lookup; + MapNamesToNodes(wildcard_result, &node_lookup); + EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar")); + EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob")); + + GraphDef targeted_result; + TransformFuncContext targeted_context; + targeted_context.input_names = {}; + targeted_context.output_names = {"mul_node1"}; + targeted_context.params.insert( + std::pair>({"op_name", {string("Mul")}})); + targeted_context.params.insert(std::pair>( + {"attribute_name", {string("foo")}})); + TF_ASSERT_OK( + RemoveAttribute(graph_def, targeted_context, &targeted_result)); + + MapNamesToNodes(targeted_result, &node_lookup); + EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob")); + } +}; + +TEST_F(RemoveAttributeTest, TestRemoveAttribute) { TestRemoveAttribute(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_device.cc b/tensorflow/tools/graph_transforms/remove_device.cc new file mode 100644 index 00000000000..7f50dd60405 --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_device.cc @@ -0,0 +1,47 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Clears the device field of all ops in the graph. +Status RemoveDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + new_node->set_device(""); + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("remove_device", RemoveDevice); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_device_test.cc b/tensorflow/tools/graph_transforms/remove_device_test.cc new file mode 100644 index 00000000000..554c5e35952 --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_device_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RemoveDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RemoveDeviceTest : public ::testing::Test { + protected: + void TestRemoveDevice() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->set_device("//cpu:0"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + add_node2->set_device("//gpu:1"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + TF_ASSERT_OK(RemoveDevice(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ("", node_lookup.at("mul_node1")->device()); + EXPECT_EQ("", node_lookup.at("add_node2")->device()); + } +}; + +TEST_F(RemoveDeviceTest, TestRemoveDevice) { TestRemoveDevice(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc new file mode 100644 index 00000000000..3290e65512a --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_nodes.cc @@ -0,0 +1,94 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Deletes any specified types of nodes, unless they're necessary for the +// graph's inputs or outputs. +Status RemoveNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + if (!context.params.count("op")) { + return errors::InvalidArgument( + "remove_nodes expects at least one 'op'" + "argument, e.g. remove_nodes(op=Identity)"); + } + + // Make sure we don't get rid of any nodes used as graph inputs or outputs. + std::set required_nodes; + for (const string& input : context.input_names) { + required_nodes.insert(NodeNameFromInput(input)); + } + for (const string& output : context.output_names) { + required_nodes.insert(NodeNameFromInput(output)); + } + + std::vector ops_to_remove = context.params.at("op"); + GraphDef current_graph_def = input_graph_def; + for (const string& op : ops_to_remove) { + // Keep looking for nodes to remove until there are no more changes. + bool any_nodes_removed; + do { + any_nodes_removed = false; + std::map inputs_to_rename; + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, {op, {{"*"}}}, + [&inputs_to_rename, &required_nodes, &any_nodes_removed]( + const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& replace_node = match.node; + // If this node is needed in the inputs or outputs don't replace it. + if (required_nodes.count(replace_node.name())) { + LOG(INFO) << "Skipping replacement for " << replace_node.name(); + CopyOriginalMatch(match, new_nodes); + return Status::OK(); + } + const NodeDef& input_node = match.inputs[0].node; + inputs_to_rename[replace_node.name()] = input_node.name(); + inputs_to_rename["^" + replace_node.name()] = + "^" + input_node.name(); + new_nodes->push_back(input_node); + any_nodes_removed = true; + return Status::OK(); + }, + {true}, &replaced_graph_def)); + // Make sure all references to removed nodes now point to their inputs. + RenameNodeInputs(replaced_graph_def, inputs_to_rename, + ¤t_graph_def); + } while (any_nodes_removed); + } + + *output_graph_def = current_graph_def; + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("remove_nodes", RemoveNodes); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_nodes_test.cc b/tensorflow/tools/graph_transforms/remove_nodes_test.cc new file mode 100644 index 00000000000..e87ea1daa6f --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_nodes_test.cc @@ -0,0 +1,222 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RemoveNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RemoveNodesTest : public ::testing::Test { + protected: + void TestRemoveNodes() { + GraphDef graph_def; + + NodeDef* add_node1 = graph_def.add_node(); + add_node1->set_name("add_node1"); + add_node1->set_op("Add"); + add_node1->add_input("add_node2"); + add_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("identity_node1"); + add_node2->add_input("identity_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("identity_node1"); + add_node3->add_input("const_node3"); + + NodeDef* identity_node1 = graph_def.add_node(); + identity_node1->set_name("identity_node1"); + identity_node1->set_op("Identity"); + identity_node1->add_input("const_node1"); + + NodeDef* identity_node2 = graph_def.add_node(); + identity_node2->set_name("identity_node2"); + identity_node2->set_op("Identity"); + identity_node2->add_input("const_node2"); + + NodeDef* identity_node3 = graph_def.add_node(); + identity_node3->set_name("identity_node3"); + identity_node3->set_op("Identity"); + identity_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"add_node1"}; + context.params.insert( + std::pair>({"op", {string("Identity")}})); + TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node1")); + EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0)); + EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1)); + EXPECT_EQ(1, node_lookup.count("add_node2")); + EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0)); + EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1)); + EXPECT_EQ(1, node_lookup.count("add_node3")); + EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0)); + EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1)); + EXPECT_EQ(1, node_lookup.count("add_node4")); + EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0)); + EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1)); + EXPECT_EQ(0, node_lookup.count("identity_node1")); + EXPECT_EQ(0, node_lookup.count("identity_node2")); + EXPECT_EQ(0, node_lookup.count("identity_node3")); + EXPECT_EQ(1, node_lookup.count("const_node1")); + EXPECT_EQ("Const", node_lookup.at("const_node1")->op()); + EXPECT_EQ(1, node_lookup.count("const_node2")); + EXPECT_EQ("Const", node_lookup.at("const_node2")->op()); + EXPECT_EQ(1, node_lookup.count("const_node3")); + EXPECT_EQ("Const", node_lookup.at("const_node3")->op()); + } + + void TestRemoveOutputNodes() { + GraphDef graph_def; + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("const_node1"); + add_node->add_input("const_node2"); + + NodeDef* identity_node = graph_def.add_node(); + identity_node->set_name("identity_node"); + identity_node->set_op("Identity"); + identity_node->add_input("add_node"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"identity_node"}; + context.params.insert( + std::pair>({"op", {string("Identity")}})); + TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); + EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1)); + EXPECT_EQ(1, node_lookup.count("identity_node")); + EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0)); + } + + void TestRemoveChainedNodes() { + GraphDef graph_def; + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* identity_node1 = graph_def.add_node(); + identity_node1->set_name("identity_node1"); + identity_node1->set_op("Identity"); + identity_node1->add_input("const_node1"); + + NodeDef* identity_node2 = graph_def.add_node(); + identity_node2->set_name("identity_node2"); + identity_node2->set_op("Identity"); + identity_node2->add_input("identity_node1"); + + NodeDef* identity_node3 = graph_def.add_node(); + identity_node3->set_name("identity_node3"); + identity_node3->set_op("Identity"); + identity_node3->add_input("identity_node2"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("identity_node3"); + add_node->add_input("const_node2"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"identity_node"}; + context.params.insert( + std::pair>({"op", {string("Identity")}})); + TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); + EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1)); + EXPECT_EQ(0, node_lookup.count("identity_node1")); + EXPECT_EQ(0, node_lookup.count("identity_node2")); + EXPECT_EQ(0, node_lookup.count("identity_node3")); + } +}; + +TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); } + +TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); } + +TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/rename_attribute.cc b/tensorflow/tools/graph_transforms/rename_attribute.cc new file mode 100644 index 00000000000..3493cc37ea2 --- /dev/null +++ b/tensorflow/tools/graph_transforms/rename_attribute.cc @@ -0,0 +1,70 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +Status RenameAttribute(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + if (!context.params.count("old_attribute_name") || + (context.params.at("old_attribute_name").size() != 1) || + !context.params.count("new_attribute_name") || + (context.params.at("new_attribute_name").size() != 1)) { + return errors::InvalidArgument( + "remove_nodes expects exactly one 'old_attribute_name' and one " + "'new_attribute_name' argument, e.g. " + "remove_attribute(old_attribute_name=foo, new_attribute_name=bar)"); + } + + string op_name; + if (context.params.count("op_name")) { + op_name = context.params.at("op_name")[0]; + } else { + op_name = "*"; + } + + const string old_attribute_name = context.params.at("old_attribute_name")[0]; + const string new_attribute_name = context.params.at("new_attribute_name")[0]; + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + if (((op_name == "*") || (op_name == node.op())) && + (node.attr().count(old_attribute_name))) { + AttrValue attribute_value = node.attr().at(old_attribute_name); + new_node->mutable_attr()->erase(old_attribute_name); + new_node->mutable_attr()->insert({new_attribute_name, attribute_value}); + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("rename_attribute", RenameAttribute); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/rename_attribute_test.cc b/tensorflow/tools/graph_transforms/rename_attribute_test.cc new file mode 100644 index 00000000000..a0a33e9fc09 --- /dev/null +++ b/tensorflow/tools/graph_transforms/rename_attribute_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RenameAttribute(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RenameAttributeTest : public ::testing::Test { + protected: + void TestRenameAttribute() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + AddNodeAttr("foo", 23, mul_node1); + AddNodeAttr("bar", "something", mul_node1); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + AddNodeAttr("foo", 46, add_node2); + AddNodeAttr("bob", 23, add_node2); + AddNodeAttr("bar", "something else", add_node2); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef wildcard_result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + context.params.insert( + std::pair>({"op_name", {string("*")}})); + context.params.insert(std::pair>( + {"old_attribute_name", {string("foo")}})); + context.params.insert(std::pair>( + {"new_attribute_name", {string("baz")}})); + TF_ASSERT_OK(RenameAttribute(graph_def, context, &wildcard_result)); + + std::map node_lookup; + MapNamesToNodes(wildcard_result, &node_lookup); + EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz")); + EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar")); + EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("baz")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob")); + + GraphDef targeted_result; + TransformFuncContext targeted_context; + targeted_context.input_names = {}; + targeted_context.output_names = {"mul_node1"}; + targeted_context.params.insert( + std::pair>({"op_name", {string("Mul")}})); + targeted_context.params.insert(std::pair>( + {"old_attribute_name", {string("foo")}})); + targeted_context.params.insert(std::pair>( + {"new_attribute_name", {string("baz")}})); + TF_ASSERT_OK( + RenameAttribute(graph_def, targeted_context, &targeted_result)); + + MapNamesToNodes(targeted_result, &node_lookup); + EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo")); + EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz")); + EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo")); + EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("baz")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar")); + EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob")); + } +}; + +TEST_F(RenameAttributeTest, TestRenameAttribute) { TestRenameAttribute(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/rename_op.cc b/tensorflow/tools/graph_transforms/rename_op.cc new file mode 100644 index 00000000000..04441d028ff --- /dev/null +++ b/tensorflow/tools/graph_transforms/rename_op.cc @@ -0,0 +1,60 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Changes the op type of a specified op. +Status RenameOp(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + if (!context.params.count("old_op_name") || + (context.params.at("old_op_name").size() != 1) || + !context.params.count("new_op_name") || + (context.params.at("new_op_name").size() != 1)) { + return errors::InvalidArgument( + "remove_nodes expects exactly one 'old_op_name' and 'new_op_name' " + "argument, e.g. rename_op(old_op_name=Mul, new_op_name=Multiply)"); + } + + const string old_op_name = context.params.at("old_op_name")[0]; + const string new_op_name = context.params.at("new_op_name")[0]; + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + if (node.op() == old_op_name) { + new_node->set_op(new_op_name); + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("rename_op", RenameOp); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/rename_op_test.cc b/tensorflow/tools/graph_transforms/rename_op_test.cc new file mode 100644 index 00000000000..d09f2abaa9e --- /dev/null +++ b/tensorflow/tools/graph_transforms/rename_op_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RenameOp(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RenameOpTest : public ::testing::Test { + protected: + void TestRenameOp() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + context.params.insert(std::pair>( + {"old_op_name", {string("Mul")}})); + context.params.insert(std::pair>( + {"new_op_name", {string("Multiply")}})); + TF_ASSERT_OK(RenameOp(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("mul_node1")); + EXPECT_EQ("Multiply", node_lookup.at("mul_node1")->op()); + EXPECT_EQ(1, node_lookup.count("add_node2")); + EXPECT_EQ("Add", node_lookup.at("add_node2")->op()); + EXPECT_EQ(1, node_lookup.count("add_node3")); + EXPECT_EQ("Add", node_lookup.at("add_node3")->op()); + EXPECT_EQ(1, node_lookup.count("add_node4")); + EXPECT_EQ("Add", node_lookup.at("add_node4")->op()); + EXPECT_EQ(1, node_lookup.count("const_node1")); + EXPECT_EQ("Const", node_lookup.at("const_node1")->op()); + EXPECT_EQ(1, node_lookup.count("const_node2")); + EXPECT_EQ("Const", node_lookup.at("const_node2")->op()); + EXPECT_EQ(1, node_lookup.count("const_node3")); + EXPECT_EQ("Const", node_lookup.at("const_node3")->op()); + } +}; + +TEST_F(RenameOpTest, TestRenameOp) { TestRenameOp(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/round_weights.cc b/tensorflow/tools/graph_transforms/round_weights.cc new file mode 100644 index 00000000000..e73aae0f393 --- /dev/null +++ b/tensorflow/tools/graph_transforms/round_weights.cc @@ -0,0 +1,123 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Rounds any large float constants to the specified number of levels. +Status RoundWeights(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + string num_steps_string; + TF_RETURN_IF_ERROR( + GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string)); + int32 num_steps; + if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) { + return errors::InvalidArgument( + "Couldn't interpret the num_steps argument to round_weights as a " + "number:", + num_steps_string); + } + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + input_graph_def, {"Const"}, + [num_steps](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& old_const_node = match.node; + if (!old_const_node.attr().count("dtype")) { + return errors::InvalidArgument("No 'dtype' attribute for Const node ", + old_const_node.name()); + } + if (!old_const_node.attr().count("value")) { + return errors::InvalidArgument("No 'value' attribute for Const node ", + old_const_node.name()); + } + const DataType old_dtype = old_const_node.attr().at("dtype").type(); + Tensor old_tensor; + if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) { + return errors::InvalidArgument("Decoding Tensor failed for node", + old_const_node.name()); + } + const size_t num_elements = old_tensor.NumElements(); + // If this isn't a float constant, or it's too small, then reuse the + // same node with no changes. The size is important because small + // constants tend to be used for more accuracy-sensitive calculations, + // and the benefit of shrinking them is very marginal. + if ((old_dtype != DT_FLOAT) || (num_elements < 16)) { + new_nodes->push_back(old_const_node); + return Status::OK(); + } + const float* old_values = old_tensor.flat().data(); + float min = std::numeric_limits::max(); + float max = std::numeric_limits::min(); + for (int i = 0; i < num_elements; ++i) { + const float value = old_values[i]; + min = std::min(min, value); + max = std::max(max, value); + } + // min_value == max_value is a tricky case. It can occur for general + // tensors, and of course for scalars. The quantized ops cannot deal + // with this case, so we set max_value to something else. + // It's a tricky question what is the numerically best solution to + // deal with this degeneracy. + // TODO(petewarden): Better use a tolerance than a hard comparison? + if (min == max) { + if (std::abs(min) < 0.000001f) { + max = min + 1.0f; + } else if (min > 0) { + max = 2.0f * min; + } else { + min = 2.0f * max; + } + } + Tensor rounded_tensor(DT_FLOAT, old_tensor.shape()); + float* rounded_values = rounded_tensor.flat().data(); + const float bucket_width = (max - min) / num_steps; + for (int i = 0; i < num_elements; ++i) { + const int32 bucket = std::floor((old_values[i] - min) / bucket_width); + rounded_values[i] = min + (bucket_width * (bucket + 0.5f)); + } + + NodeDef rounded_const_node; + rounded_const_node.set_op("Const"); + rounded_const_node.set_name(old_const_node.name()); + SetNodeAttr("dtype", DT_FLOAT, &rounded_const_node); + SetNodeTensorAttr("value", rounded_tensor, &rounded_const_node); + new_nodes->push_back(rounded_const_node); + + return Status::OK(); + }, + {}, output_graph_def)); + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("round_weights", RoundWeights); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/round_weights_test.cc b/tensorflow/tools/graph_transforms/round_weights_test.cc new file mode 100644 index 00000000000..74700a2760c --- /dev/null +++ b/tensorflow/tools/graph_transforms/round_weights_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RoundWeights(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class RoundWeightsTest : public ::testing::Test { + protected: + void TestRoundWeights() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2})); + test::FillValues( + &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f, + -5.0f, -3.0f, -6.0f}); + Output input_op = + Const(root.WithOpName("input_op"), Input::Initializer(input_data)); + + Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10})); + test::FillValues( + &weights_data, + {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, + 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, + 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, + 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f}); + Output weights_op = + Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); + + Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op, + {1, 1, 1, 1}, "VALID"); + + GraphDef original_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); + + std::unique_ptr original_session(NewSession(SessionOptions())); + TF_ASSERT_OK(original_session->Create(original_graph_def)); + std::vector original_outputs; + TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs)); + + GraphDef rounded_graph_def; + TF_ASSERT_OK( + RoundWeights(original_graph_def, {{}, {"output"}}, &rounded_graph_def)); + + std::unique_ptr rounded_session(NewSession(SessionOptions())); + TF_ASSERT_OK(rounded_session->Create(rounded_graph_def)); + std::vector rounded_outputs; + TF_ASSERT_OK(rounded_session->Run({}, {"output"}, {}, &rounded_outputs)); + + test::ExpectTensorNear(original_outputs[0], rounded_outputs[0], 0.5); + + std::map node_lookup; + MapNamesToNodes(rounded_graph_def, &node_lookup); + EXPECT_EQ(1, node_lookup.count("input_op")); + const NodeDef* r_input_op = node_lookup.at("input_op"); + EXPECT_EQ(DT_FLOAT, r_input_op->attr().at("dtype").type()); + EXPECT_EQ(1, node_lookup.count("weights_op")); + const NodeDef* r_weights_op = node_lookup.at("weights_op"); + EXPECT_EQ("Const", r_weights_op->op()); + EXPECT_EQ(DT_FLOAT, r_weights_op->attr().at("dtype").type()); + } +}; + +TEST_F(RoundWeightsTest, TestRoundWeights) { TestRoundWeights(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/sort_by_execution_order.cc b/tensorflow/tools/graph_transforms/sort_by_execution_order.cc new file mode 100644 index 00000000000..43152d20fcc --- /dev/null +++ b/tensorflow/tools/graph_transforms/sort_by_execution_order.cc @@ -0,0 +1,43 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// This is a thin wrapper with the standard TransformFunc interface to the +// underlying utility function. The only difference is that we don't use the +// input or output name arguments. +Status SortByExecutionOrderWithUnusedContext( + const GraphDef& input_graph_def, const TransformFuncContext& unused_context, + GraphDef* output_graph_def) { + return SortByExecutionOrder(input_graph_def, output_graph_def); +} + +REGISTER_GRAPH_TRANSFORM("sort_by_execution_order", + SortByExecutionOrderWithUnusedContext); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/sort_by_execution_order_test.cc b/tensorflow/tools/graph_transforms/sort_by_execution_order_test.cc new file mode 100644 index 00000000000..a995f390092 --- /dev/null +++ b/tensorflow/tools/graph_transforms/sort_by_execution_order_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +class SortByExecutionOrderTest : public ::testing::Test { + protected: + void GetOrder(const GraphDef& graph_def, std::map* order) { + for (int i = 0; i < graph_def.node_size(); ++i) { + const NodeDef& node = graph_def.node(i); + (*order)[node.name()] = i; + } + } + + void TestSimpleAdd() { + GraphDef graph_def; + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("a_node"); + add_node->add_input("b_node"); + + NodeDef* b_node = graph_def.add_node(); + b_node->set_name("b_node"); + b_node->set_op("Const"); + + NodeDef* a_node = graph_def.add_node(); + a_node->set_name("a_node"); + a_node->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); + + std::map order; + GetOrder(result, &order); + EXPECT_EQ(2, order["add_node"]); + EXPECT_GT(2, order["a_node"]); + EXPECT_GT(2, order["b_node"]); + } + + void TestSimpleLinear() { + GraphDef graph_def; + + NodeDef* negative_node = graph_def.add_node(); + negative_node->set_name("negative_node"); + negative_node->set_op("Negative"); + negative_node->add_input("sqrt_node"); + + NodeDef* relu_node = graph_def.add_node(); + relu_node->set_name("relu_node"); + relu_node->set_op("Relu"); + relu_node->add_input("const_node"); + + NodeDef* sqrt_node = graph_def.add_node(); + sqrt_node->set_name("sqrt_node"); + sqrt_node->set_op("Sqrt"); + sqrt_node->add_input("relu_node"); + + NodeDef* const_node = graph_def.add_node(); + const_node->set_name("const_node"); + const_node->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); + + std::map order; + GetOrder(result, &order); + EXPECT_EQ(3, order["negative_node"]); + EXPECT_EQ(2, order["sqrt_node"]); + EXPECT_EQ(1, order["relu_node"]); + EXPECT_EQ(0, order["const_node"]); + } + + void TestSimpleTree() { + GraphDef graph_def; + + NodeDef* add_node1 = graph_def.add_node(); + add_node1->set_name("add_node1"); + add_node1->set_op("Add"); + add_node1->add_input("add_node2"); + add_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node3"); + add_node3->add_input("const_node4"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* const_node4 = graph_def.add_node(); + const_node4->set_name("const_node4"); + const_node4->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); + + std::map order; + GetOrder(result, &order); + EXPECT_EQ(6, order["add_node1"]); + EXPECT_GT(6, order["add_node2"]); + EXPECT_GT(6, order["add_node3"]); + EXPECT_GT(5, order["const_node1"]); + EXPECT_GT(5, order["const_node2"]); + EXPECT_GT(5, order["const_node3"]); + EXPECT_GT(5, order["const_node4"]); + } + + void TestCommonAncestor() { + GraphDef graph_def; + + NodeDef* add_node1 = graph_def.add_node(); + add_node1->set_name("add_node1"); + add_node1->set_op("Add"); + add_node1->add_input("add_node2"); + add_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result)); + + std::map order; + GetOrder(result, &order); + EXPECT_EQ(5, order["add_node1"]); + EXPECT_GT(5, order["add_node2"]); + EXPECT_GT(5, order["add_node3"]); + EXPECT_GT(4, order["const_node2"]); + EXPECT_GT(4, order["const_node3"]); + EXPECT_GT(3, order["const_node1"]); + } +}; + +TEST_F(SortByExecutionOrderTest, TestSimpleAdd) { TestSimpleAdd(); } + +TEST_F(SortByExecutionOrderTest, TestSimpleLinear) { TestSimpleLinear(); } + +TEST_F(SortByExecutionOrderTest, TestSimpleTree) { TestSimpleTree(); } + +TEST_F(SortByExecutionOrderTest, TestCommonAncestor) { TestCommonAncestor(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc new file mode 100644 index 00000000000..786bf4f6da1 --- /dev/null +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc @@ -0,0 +1,215 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/fold_constants_lib.h" + +#include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +namespace { + +Status TypeForPlaceholder(const TransformFuncContext& context, + const string& node_name, DataType* result) { + // If we don't find anything else, return float. + *result = DT_FLOAT; + + // Check to see if we have been given a default for all placeholders. + if (context.params.count("type")) { + if (context.params.at("type").size() != 1) { + return errors::InvalidArgument( + "You must pass no more than one default 'type' to " + "strip_unused_nodes"); + } + const string& type_string = context.params.at("type")[0]; + if (!DataTypeFromString(type_string, result)) { + return errors::InvalidArgument("Couldn't understand type argument '", + type_string, "'"); + } + } + + // See if there's a particular type specified for this placeholder. + if (context.params.count("name") || context.params.count("type_for_name")) { + if (!context.params.count("name") || + !context.params.count("type_for_name") || + (context.params.at("type_for_name").size() != + context.params.at("name").size())) { + return errors::InvalidArgument( + "You must pass a 'type_for_name' arg for every 'name', e.g. " + "strip_unused_nodes(name=foo, type_for_name=float, name=bar, " + "type_for_name=quint8"); + } + const int name_count = context.params.at("name").size(); + for (int i = 0; i < name_count; ++i) { + if (context.params.at("name")[i] == node_name) { + const string& type_string = context.params.at("type_for_name")[i]; + if (!DataTypeFromString(type_string, result)) { + return errors::InvalidArgument("Couldn't understand type argument '", + type_string, "'"); + } + } + } + } + + return Status::OK(); +} + +// Takes a comma-separated string of numbers and parses them into a shape. +bool TensorShapeFromString(const string& shape_string, TensorShape* result) { + if (shape_string == "") { + return false; + } + std::vector dims; + if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) { + return false; + } + *result = TensorShape(dims); + return true; +} + +Status ShapeForPlaceholder(const TransformFuncContext& context, + const string& node_name, TensorShape* result) { + // If we don't find anything else, return scalar. + *result = {}; + + // Check to see if we have been given a default for all placeholders. + if (context.params.count("type")) { + if (context.params.at("shape").size() != 1) { + return errors::InvalidArgument( + "You must pass no more than one default 'shape' to " + "strip_unused_nodes"); + } + const string& shape_string = context.params.at("shape")[0]; + if (!TensorShapeFromString(shape_string, result)) { + return errors::InvalidArgument("Couldn't understand shape argument '", + shape_string, "'"); + } + } + + // See if there's a particular type specified for this placeholder. + if (context.params.count("name") || context.params.count("type_for_name")) { + if (!context.params.count("name") || + !context.params.count("type_for_name") || + (context.params.at("type_for_name").size() != + context.params.at("name").size())) { + return errors::InvalidArgument( + "You must pass a 'shape_for_name' arg for every 'name', e.g. " + "strip_unused_nodes(name=foo, shape_for_name=\"2,2,1\", name=bar, " + "shape_for_name=\"1\""); + } + const int name_count = context.params.at("name").size(); + for (int i = 0; i < name_count; ++i) { + if (context.params.at("name")[i] == node_name) { + const string& shape_string = context.params.at("shape_for_name")[i]; + if (!TensorShapeFromString(shape_string, result)) { + return errors::InvalidArgument("Couldn't understand shape argument '", + shape_string, "'"); + } + } + } + } + + return Status::OK(); +} +} // namespace + +// Delete any nodes that don't contribute to the inference result. +Status StripUnusedNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + std::set required_nodes; + std::set input_nodes; + for (const string& input : context.input_names) { + required_nodes.insert(NodeNameFromInput(input)); + input_nodes.insert(NodeNameFromInput(input)); + } + for (const string& output : context.output_names) { + required_nodes.insert(output); + } + + std::map node_lookup; + MapNamesToNodes(input_graph_def, &node_lookup); + + std::vector current_inputs; + for (const string& output_name : context.output_names) { + current_inputs.push_back(NodeNameFromInput(output_name)); + } + + while (!current_inputs.empty()) { + std::set next_inputs; + for (const string& current_input : current_inputs) { + required_nodes.insert(current_input); + if (input_nodes.count(current_input)) { + continue; + } + if (!node_lookup.count(current_input)) { + return errors::InvalidArgument("Input node ", current_input, + " not found in graph"); + } + const NodeDef* current_node = node_lookup[current_input]; + for (const string& input_name : current_node->input()) { + string input_node_name = NodeNameFromInput(input_name); + if (!required_nodes.count(input_node_name)) { + next_inputs.insert(input_node_name); + } + } + } + current_inputs = + std::vector(next_inputs.begin(), next_inputs.end()); + } + + GraphDef filtered_graph_def; + FilterGraphDef(input_graph_def, + [&](const NodeDef& node) { + return required_nodes.count(node.name()) > 0; + }, + &filtered_graph_def); + + output_graph_def->Clear(); + for (const NodeDef& node : filtered_graph_def.node()) { + if (input_nodes.count(node.name())) { + NodeDef placeholder_node; + if (node.op() == "Placeholder") { + placeholder_node.CopyFrom(node); + } else { + placeholder_node.set_op("Placeholder"); + placeholder_node.set_name(node.name()); + DataType type; + TF_RETURN_IF_ERROR(TypeForPlaceholder(context, node.name(), &type)); + TensorShape shape; + TF_RETURN_IF_ERROR(ShapeForPlaceholder(context, node.name(), &shape)); + SetNodeAttr("dtype", type, &placeholder_node); + SetNodeAttr("shape", shape, &placeholder_node); + } + *(output_graph_def->mutable_node()->Add()) = placeholder_node; + } else { + *(output_graph_def->mutable_node()->Add()) = node; + } + } + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("strip_unused_nodes", StripUnusedNodes); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc new file mode 100644 index 00000000000..4eb074998f7 --- /dev/null +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc @@ -0,0 +1,286 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status StripUnusedNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class StripUnusedNodesTest : public ::testing::Test { + protected: + void TestSimpleAdd() { + GraphDef graph_def; + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("a_node"); + add_node->add_input("b_node"); + + NodeDef* a_node = graph_def.add_node(); + a_node->set_name("a_node"); + a_node->set_op("Const"); + + NodeDef* b_node = graph_def.add_node(); + b_node->set_name("b_node"); + b_node->set_op("Const"); + + NodeDef* c_node = graph_def.add_node(); + c_node->set_name("c_node"); + c_node->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(StripUnusedNodes(graph_def, {{}, {"add_node"}}, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ(1, node_lookup.count("a_node")); + EXPECT_EQ(1, node_lookup.count("b_node")); + EXPECT_EQ(0, node_lookup.count("c_node")); + } + + void TestCommonAncestor() { + GraphDef graph_def; + + NodeDef* add_node1 = graph_def.add_node(); + add_node1->set_name("add_node1"); + add_node1->set_op("Add"); + add_node1->add_input("add_node2"); + add_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* dangling_input = graph_def.add_node(); + dangling_input->set_name("dangling_input"); + dangling_input->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + GraphDef result; + TF_ASSERT_OK(StripUnusedNodes( + graph_def, {{"dangling_input"}, {"add_node1"}}, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node1")); + EXPECT_EQ(1, node_lookup.count("add_node2")); + EXPECT_EQ(1, node_lookup.count("add_node3")); + EXPECT_EQ(0, node_lookup.count("add_node4")); + EXPECT_EQ(1, node_lookup.count("const_node1")); + EXPECT_EQ(1, node_lookup.count("const_node2")); + EXPECT_EQ(1, node_lookup.count("const_node3")); + EXPECT_EQ(0, node_lookup.count("const_node4")); + EXPECT_EQ(1, node_lookup.count("dangling_input")); + } + + void TestSimplePlaceholder() { + GraphDef graph_def; + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("mul_node"); + add_node->add_input("a_node"); + + NodeDef* mul_node = graph_def.add_node(); + mul_node->set_name("mul_node"); + mul_node->set_op("Mul"); + mul_node->add_input("b_node"); + mul_node->add_input("c_node"); + + NodeDef* a_node = graph_def.add_node(); + a_node->set_name("a_node"); + a_node->set_op("Const"); + + NodeDef* b_node = graph_def.add_node(); + b_node->set_name("b_node"); + b_node->set_op("Const"); + + NodeDef* c_node = graph_def.add_node(); + c_node->set_name("c_node"); + c_node->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK( + StripUnusedNodes(graph_def, {{"mul_node"}, {"add_node"}}, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ(1, node_lookup.count("mul_node")); + EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op()); + EXPECT_EQ(DT_FLOAT, node_lookup["mul_node"]->attr().at("dtype").type()); + EXPECT_EQ(TensorShape({}), + TensorShape(node_lookup["mul_node"]->attr().at("shape").shape())); + EXPECT_EQ(1, node_lookup.count("a_node")); + EXPECT_EQ(0, node_lookup.count("b_node")); + EXPECT_EQ(0, node_lookup.count("c_node")); + } + + void TestPlaceholderDefaultArgs() { + GraphDef graph_def; + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("mul_node"); + add_node->add_input("a_node"); + + NodeDef* mul_node = graph_def.add_node(); + mul_node->set_name("mul_node"); + mul_node->set_op("Mul"); + mul_node->add_input("b_node"); + mul_node->add_input("c_node"); + + NodeDef* a_node = graph_def.add_node(); + a_node->set_name("a_node"); + a_node->set_op("Const"); + + NodeDef* b_node = graph_def.add_node(); + b_node->set_name("b_node"); + b_node->set_op("Const"); + + NodeDef* c_node = graph_def.add_node(); + c_node->set_name("c_node"); + c_node->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(StripUnusedNodes(graph_def, + {{"mul_node"}, + {"add_node"}, + {{"type", {"int32"}}, {"shape", {"1,2,3"}}}}, + &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ(1, node_lookup.count("mul_node")); + EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op()); + EXPECT_EQ(DT_INT32, node_lookup["mul_node"]->attr().at("dtype").type()); + EXPECT_EQ(TensorShape({1, 2, 3}), + TensorShape(node_lookup["mul_node"]->attr().at("shape").shape())); + EXPECT_EQ(1, node_lookup.count("a_node")); + EXPECT_EQ(0, node_lookup.count("b_node")); + EXPECT_EQ(0, node_lookup.count("c_node")); + } + + void TestPlaceholderNamedArgs() { + GraphDef graph_def; + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("mul_node"); + add_node->add_input("a_node"); + + NodeDef* mul_node = graph_def.add_node(); + mul_node->set_name("mul_node"); + mul_node->set_op("Mul"); + mul_node->add_input("b_node"); + mul_node->add_input("c_node"); + + NodeDef* a_node = graph_def.add_node(); + a_node->set_name("a_node"); + a_node->set_op("Const"); + + NodeDef* b_node = graph_def.add_node(); + b_node->set_name("b_node"); + b_node->set_op("Const"); + + NodeDef* c_node = graph_def.add_node(); + c_node->set_name("c_node"); + c_node->set_op("Const"); + + GraphDef result; + TF_ASSERT_OK(StripUnusedNodes(graph_def, + {{"mul_node", "a_node"}, + {"add_node"}, + {{"name", {"a_node", "mul_node"}}, + {"type_for_name", {"int64", "quint8"}}, + {"shape_for_name", {"1,2", "1, 2, 3"}}}}, + &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ(1, node_lookup.count("mul_node")); + EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op()); + EXPECT_EQ(DT_QUINT8, node_lookup["mul_node"]->attr().at("dtype").type()); + EXPECT_EQ(TensorShape({1, 2, 3}), + TensorShape(node_lookup["mul_node"]->attr().at("shape").shape())); + EXPECT_EQ(1, node_lookup.count("a_node")); + EXPECT_EQ("Placeholder", node_lookup["a_node"]->op()); + EXPECT_EQ(DT_INT64, node_lookup["a_node"]->attr().at("dtype").type()); + EXPECT_EQ(TensorShape({1, 2}), + TensorShape(node_lookup["a_node"]->attr().at("shape").shape())); + EXPECT_EQ(0, node_lookup.count("b_node")); + EXPECT_EQ(0, node_lookup.count("c_node")); + } +}; + +TEST_F(StripUnusedNodesTest, TestSimpleAdd) { TestSimpleAdd(); } + +TEST_F(StripUnusedNodesTest, TestCommonAncestor) { TestCommonAncestor(); } + +TEST_F(StripUnusedNodesTest, TestSimplePlaceholder) { TestSimplePlaceholder(); } + +TEST_F(StripUnusedNodesTest, TestPlaceholderDefaultArgs) { + TestPlaceholderDefaultArgs(); +} + +TEST_F(StripUnusedNodesTest, TestPlaceholderNamedArgs) { + TestPlaceholderNamedArgs(); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc new file mode 100644 index 00000000000..638296b9231 --- /dev/null +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -0,0 +1,205 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This program prints out a summary of a GraphDef file's contents, listing +// things that are useful for debugging and reusing the model it contains. For +// example it looks at the graph structure and op types to figure out likely +// input and output nodes, and shows which ops are used by the graph. To use it, +// run something like this: +// +// bazel build tensorflow/tools/graph_transforms:summarize_graph +// bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ +// --in_graph=my_graph.pb + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { +namespace { + +Status SummarizeGraph(const GraphDef& graph) { + std::vector placeholders; + for (const NodeDef& node : graph.node()) { + if (node.op() == "Placeholder") { + placeholders.push_back(&node); + } + } + + if (placeholders.empty()) { + std::cout << "No inputs spotted." << std::endl; + } else { + std::cout << "Found " << placeholders.size() << " possible inputs: "; + for (const NodeDef* node : placeholders) { + TensorShape shape; + if (node->attr().count("shape")) { + TensorShapeProto shape_proto = node->attr().at("shape").shape(); + shape = TensorShape(shape_proto); + } + DataType dtype = node->attr().at("dtype").type(); + std::cout << "(name=" << node->name(); + std::cout << ", type=" << DataTypeString(dtype) << "(" << dtype << ")"; + std::cout << ", shape=" << shape.DebugString() << ") "; + } + std::cout << std::endl; + } + + std::map> output_map; + MapNodesToOutputs(graph, &output_map); + std::vector outputs; + for (const NodeDef& node : graph.node()) { + if (output_map.count(node.name()) == 0) { + outputs.push_back(&node); + } + } + + if (outputs.empty()) { + std::cout << "No outputs spotted." << std::endl; + } else { + std::cout << "Found " << outputs.size() << " possible outputs: "; + for (const NodeDef* node : outputs) { + std::cout << "(name=" << node->name(); + std::cout << ", op=" << node->op() << ") "; + } + std::cout << std::endl; + } + + int const_count = 0; + int variable_count = 0; + int identity_count = 0; + int control_edge_count = 0; + std::map device_counts; + for (const NodeDef& node : graph.node()) { + if (node.op() == "Const") { + ++const_count; + } else if (node.op() == "Variable") { + ++variable_count; + } else if (node.op() == "Identity") { + ++identity_count; + } + for (const string& input : node.input()) { + if (input.substr(0, 1) == "^") { + ++control_edge_count; + } + } + if (node.device() != "") { + ++device_counts[node.device()]; + } + } + + std::cout << "Found " << const_count << " consts, " << variable_count + << " variables, " << identity_count << " identities, and " + << control_edge_count << " control_edges" << std::endl; + if (!device_counts.empty()) { + for (const auto& device_info : device_counts) { + std::cout << device_info.second << " nodes assigned to device '" + << device_info.first << "'"; + } + } + + std::vector> invalid_inputs; + FindInvalidInputs(graph, &invalid_inputs); + if (!invalid_inputs.empty()) { + for (const std::pair& invalid_input : invalid_inputs) { + std::cout << "Invalid input " << invalid_input.second << " for node " + << invalid_input.first << std::endl; + } + return errors::Internal( + "Invalid graph with inputs referring to nonexistent nodes"); + } + + std::map op_counts; + for (const NodeDef& node : graph.node()) { + ++op_counts[node.op()]; + } + std::vector> op_counts_vec(op_counts.begin(), + op_counts.end()); + std::sort(op_counts_vec.begin(), op_counts_vec.end(), + [](std::pair a, std::pair b) { + return (a.second > b.second); + }); + std::cout << "Op types used: "; + bool is_first = true; + for (const std::pair& op_count : op_counts_vec) { + if (!is_first) { + std::cout << ", "; + } else { + is_first = false; + } + std::cout << op_count.second << " " << op_count.first; + } + std::cout << std::endl; + + return Status::OK(); +} + +int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) { + string in_graph = ""; + string out_graph = ""; + string inputs_string = ""; + string outputs_string = ""; + string transforms_string = ""; + std::vector flag_list = { + Flag("in_graph", &in_graph, "input graph file name"), + }; + string usage = Flags::Usage(argv[0], flag_list); + + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + // We need to call this to set up global state for TensorFlow. + port::InitMain(argv[0], &argc, &argv); + + if (!parse_result) { + LOG(ERROR) << usage; + return -1; + } + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage; + return -1; + } + if (in_graph.empty()) { + LOG(ERROR) << "in_graph graph can't be empty.\n" << usage; + return -1; + } + + GraphDef graph_def; + Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def); + if (!load_status.ok()) { + LOG(ERROR) << "Loading graph '" << in_graph << "' failed with " + << load_status.error_message(); + LOG(ERROR) << usage; + return -1; + } + + Status summarize_result = SummarizeGraph(graph_def); + if (!summarize_result.ok()) { + LOG(ERROR) << summarize_result.error_message() << "\n" << usage; + return -1; + } + + return 0; +} + +} // namespace +} // namespace graph_transforms +} // namespace tensorflow + +int main(int argc, char* argv[]) { + return tensorflow::graph_transforms::ParseFlagsAndSummarizeGraph(argc, argv); +} diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc new file mode 100644 index 00000000000..5e71b0bd5cd --- /dev/null +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -0,0 +1,280 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/transform_graph.h" + +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +using tensorflow::strings::Scanner; + +Status ParseTransformParameters(const string& transforms_string, + TransformParameters* params_list) { + params_list->clear(); + enum { + TRANSFORM_NAME, + TRANSFORM_PARAM_NAME, + TRANSFORM_PARAM_VALUE, + } state = TRANSFORM_NAME; + StringPiece remaining(transforms_string); + StringPiece match; + StringPiece transform_name; + StringPiece parameter_name; + StringPiece parameter_value; + TransformFuncParameters func_parameters; + while (!remaining.empty()) { + if (state == TRANSFORM_NAME) { + // Reset the list of parameters. + func_parameters.clear(); + // Eat up any leading spaces. + Scanner(remaining).Any(Scanner::SPACE).GetResult(&remaining, &match); + // See if we have a valid transform name. + const bool found_transform_name = + Scanner(remaining) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .GetResult(&remaining, &transform_name); + if (!found_transform_name) { + return errors::InvalidArgument("Looking for transform name, but found ", + remaining.ToString().c_str()); + } + if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) { + state = TRANSFORM_PARAM_NAME; + } else { + // Add a transform with no parameters. + params_list->push_back({transform_name.ToString(), func_parameters}); + transform_name = ""; + state = TRANSFORM_NAME; + } + } else if (state == TRANSFORM_PARAM_NAME) { + if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) { + params_list->push_back({transform_name.ToString(), func_parameters}); + transform_name = ""; + state = TRANSFORM_NAME; + } else { + // Eat up any leading spaces or commas. + Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match); + Scanner(remaining).Any(Scanner::SPACE).GetResult(&remaining, &match); + // See if we have a valid parameter name. + const bool found_parameter_name = + Scanner(remaining) + .Any(Scanner::LETTER_DIGIT_UNDERSCORE) + .GetResult(&remaining, ¶meter_name); + if (!found_parameter_name) { + return errors::InvalidArgument( + "Looking for parameter name, but found ", + remaining.ToString().c_str()); + } + if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) { + state = TRANSFORM_PARAM_VALUE; + } else { + return errors::InvalidArgument("Looking for =, but found ", + remaining.ToString().c_str()); + } + } + } else if (state == TRANSFORM_PARAM_VALUE) { + bool found_parameter_value; + // Deal with quoted values. + if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) { + found_parameter_value = + Scanner(remaining).ScanEscapedUntil('"').GetResult( + &remaining, ¶meter_value); + if (found_parameter_value) { + Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match); + } + } else { + // See if we have a valid parameter name. + found_parameter_value = + Scanner(remaining) + .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .GetResult(&remaining, ¶meter_value); + } + if (!found_parameter_value) { + return errors::InvalidArgument("Looking for parameter name, but found ", + remaining.ToString().c_str()); + } + func_parameters[parameter_name.ToString()].push_back( + parameter_value.ToString()); + // Eat up any trailing quotes. + Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match); + Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match); + state = TRANSFORM_PARAM_NAME; + } + } + return Status::OK(); +} + +int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) { + string in_graph = ""; + string out_graph = ""; + string inputs_string = ""; + string outputs_string = ""; + string transforms_string = ""; + std::vector flag_list = { + Flag("in_graph", &in_graph, "input graph file name"), + Flag("out_graph", &out_graph, "output graph file name"), + Flag("inputs", &inputs_string, "inputs"), + Flag("outputs", &outputs_string, "outputs"), + Flag("transforms", &transforms_string, "list of transforms"), + }; + string usage = Flags::Usage(argv[0], flag_list); + usage += "\nTransforms are:\n"; + TransformRegistry* transform_registry = GetTransformRegistry(); + for (const auto& pair : *transform_registry) { + usage += pair.first + "\n"; + } + + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + // We need to call this to set up global state for TensorFlow. + if (init_main) { + port::InitMain(argv[0], &argc, &argv); + } + if (!parse_result) { + LOG(ERROR) << usage; + return -1; + } + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage; + return -1; + } + if (in_graph.empty()) { + LOG(ERROR) << "in_graph graph can't be empty.\n" << usage; + return -1; + } + if (out_graph.empty()) { + LOG(ERROR) << "out_graph graph can't be empty.\n" << usage; + return -1; + } + if (transforms_string.empty()) { + LOG(ERROR) << "You must specify at least one transform.\n" << usage; + return -1; + } + + std::vector inputs = str_util::Split(inputs_string, ','); + std::vector outputs = str_util::Split(outputs_string, ','); + TransformParameters transform_params; + Status parse_status = + ParseTransformParameters(transforms_string, &transform_params); + if (!parse_status.ok()) { + LOG(ERROR) << "Failed to parse --transform argument, error was " + << parse_status.error_message(); + return -1; + } + if (transform_params.empty()) { + LOG(ERROR) << "You must specify at least one transform.\n" << usage; + return -1; + } + + GraphDef graph_def; + Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def); + if (!load_status.ok()) { + LOG(ERROR) << "Loading graph '" << in_graph << "' failed with " + << load_status.error_message(); + LOG(ERROR) << usage; + return -1; + } + + Status transform_result = + TransformGraph(inputs, outputs, transform_params, &graph_def); + + if (!transform_result.ok()) { + LOG(ERROR) << transform_result.error_message(); + LOG(ERROR) << usage; + return -1; + } + + Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def); + if (!save_status.ok()) { + LOG(ERROR) << "Saving graph '" << out_graph << "' failed with " + << save_status.error_message(); + return -1; + } + + return 0; +} + +Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, + bool* ignore_errors) { + *ignore_errors = false; + if (transform_params.count("ignore_errors") && + (!transform_params.at("ignore_errors").empty())) { + const string& ignore_errors_string = + str_util::Lowercase(transform_params.at("ignore_errors").at(0)); + if (ignore_errors_string == "true") { + *ignore_errors = true; + } else if (ignore_errors_string == "false") { + *ignore_errors = false; + } else { + return errors::InvalidArgument( + "ignore_errors should be true or false, found ", + ignore_errors_string); + } + } + return Status::OK(); +} + +Status TransformGraph(const std::vector& inputs, + const std::vector& outputs, + const TransformParameters& transform_params, + GraphDef* graph_def) { + TransformRegistry* transform_registry = GetTransformRegistry(); + for (const auto& transform_info : transform_params) { + const string& transform_name = transform_info.first; + if (transform_name == "") { + continue; + } + if (!transform_registry->count(transform_name)) { + return errors::InvalidArgument("Transform '", transform_name, + "' not recognized."); + } + LOG(INFO) << "Applying " << transform_name; + const TransformFunc& transform_func = + transform_registry->at(transform_name); + TransformFuncContext context; + context.input_names = inputs; + context.output_names = outputs; + context.params = transform_info.second; + bool ignore_errors; + TF_RETURN_IF_ERROR( + ShouldIgnoreErrors(transform_info.second, &ignore_errors)); + GraphDef transformed_graph_def; + Status transform_result = + transform_func(*graph_def, context, &transformed_graph_def); + if (!transform_result.ok()) { + if (ignore_errors) { + LOG(ERROR) << transform_name << ": Ignoring error " + << transform_result.error_message(); + transformed_graph_def = *graph_def; + } else { + return transform_result; + } + } + // Copy over the library from the original input graph. + transformed_graph_def.mutable_library()->CopyFrom(graph_def->library()); + TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def)); + + *graph_def = transformed_graph_def; + } + return Status::OK(); +} +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_graph.h b/tensorflow/tools/graph_transforms/transform_graph.h new file mode 100644 index 00000000000..58ec1419317 --- /dev/null +++ b/tensorflow/tools/graph_transforms/transform_graph.h @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_ +#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Convenience function to handle argument parsing for the command line tool. +// If init_main is false, we're testing so don't call core initialization. +int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main); + +// Handles converting the transforms string into transform names and their +// arguments. +typedef std::vector> + TransformParameters; +Status ParseTransformParameters(const string& transforms_string, + TransformParameters* params_list); + +// Applies a series of transformations to the GraphDef. These transforms are +// defined by modules that call REGISTER_GRAPH_TRANSFORM() to associate a +// function with a name string. +Status TransformGraph(const std::vector& inputs, + const std::vector& outputs, + const TransformParameters& transform_params, + GraphDef* graph_def); + +} // namespace graph_transforms +} // namespace tensorflow + +#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_ diff --git a/tensorflow/tools/graph_transforms/transform_graph_main.cc b/tensorflow/tools/graph_transforms/transform_graph_main.cc new file mode 100644 index 00000000000..1244c19c3ac --- /dev/null +++ b/tensorflow/tools/graph_transforms/transform_graph_main.cc @@ -0,0 +1,53 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tool that applies a series of transformations to a frozen GraphDef file. +// It takes a flexible list of transforms either on the command line, and runs +// those on the incoming graph to produce the result. This allows you to build a +// processing pipeline when preparing models for deployment. +// +// bazel build tensorflow/tools/graph_transforms/fold_constants_tool && +// bazel-bin/tensorflow/tools/graph_transforms/fold_constants_tool \ +// --in_graph=graph_def.pb \ +// --out_graph=transformed_graph_def.pb \ +// --inputs=input1,input2 \ +// --outputs=output1,output2 \ +// --transforms="fold_constants order_nodes" +// +// Parameters: +// in_graph - name of a file with a frozen GraphDef proto in binary format. +// out_graph - name of the output file to save the transformed version to. +// inputs - layer names of the nodes that will be fed data. +// outputs - layer names of the nodes that will be read from after running. +// transforms - space-separated names of the transforms to apply. +// +// List of implemented transforms: +// fold_constants - Merges constant expression subgraphs into single constants, +// which can help reduce the number of ops and make subsequent transforms +// optimizations more effective. +// order_nodes - Sorts the GraphDef nodes in execution order, which can help +// simple inference engines that want to avoid complexity in their executors. + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_graph.h" + +int main(int argc, char* argv[]) { + return tensorflow::graph_transforms::ParseFlagsAndTransformGraph(argc, argv, + true); +} diff --git a/tensorflow/tools/graph_transforms/transform_graph_test.cc b/tensorflow/tools/graph_transforms/transform_graph_test.cc new file mode 100644 index 00000000000..12df4051fbb --- /dev/null +++ b/tensorflow/tools/graph_transforms/transform_graph_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/transform_graph.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declared here so we don't have to expose it in the public header. +Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, + bool* ignore_errors); + +namespace { +Status test_empty_graph_transform(const GraphDef& graph_def, + const TransformFuncContext& context, + GraphDef* result) { + result->Clear(); + return Status::OK(); +} +} // namespace + +REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform", + test_empty_graph_transform); + +class TransformGraphTest : public ::testing::Test { + protected: + void TestConstantFolding() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = + Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = + Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const); + + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + + Output mul = + Mul(root.WithOpName("output_expect_remains"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + string graph_def_serialized; + graph_def.SerializeToString(&graph_def_serialized); + const string dir = testing::TmpDir(); + const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb"); + const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb"); + TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb, + graph_def_serialized)); + + std::vector args = {"some_binary", + "--in_graph=" + in_filename_pb, + "--out_graph=" + out_filename_pb, + "--inputs=placeholder_expect_remains", + "--outputs=output_expect_remains", + "--transforms=fold_constants"}; + const int argc = 6; + EXPECT_EQ(argc, args.size()); + char* argv[argc]; + std::vector char_strings; + for (int i = 0; i < argc; ++i) { + string arg = args[i]; + char* char_string = new char[arg.size() + 1]; + std::copy_n(arg.c_str(), arg.size() + 1, char_string); + argv[i] = char_string; + char_strings.push_back(char_string); + } + ParseFlagsAndTransformGraph(argc, argv, false); + for (char* char_string : char_strings) { + delete[] char_string; + } + + GraphDef out_graph_def; + TF_EXPECT_OK( + ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def)); + + std::map out_node_map; + graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map); + + for (const NodeDef& node : out_graph_def.node()) { + const StringPiece name(node.name()); + const int occurrence_count = out_node_map.count(node.name()); + if (name.ends_with("expect_removed")) { + EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name(); + } + if (name.ends_with("expect_remains")) { + EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name(); + } + } + } + + void TestTransformRegistration() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + EXPECT_EQ(1, graph_def.node().size()); + TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}}, + &graph_def)); + EXPECT_EQ(0, graph_def.node().size()); + + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + Status no_such_status = + TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def); + EXPECT_TRUE( + StringPiece(no_such_status.ToString()).contains("not recognized")); + } + + void TestParseTransformParameters() { + TransformParameters params_list; + + ParseTransformParameters("foo", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_TRUE(params_list[0].second.empty()); + + ParseTransformParameters("foo bar", ¶ms_list); + EXPECT_EQ(2, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_TRUE(params_list[0].second.empty()); + EXPECT_EQ("bar", params_list[1].first); + EXPECT_TRUE(params_list[1].second.empty()); + + ParseTransformParameters("foo() bar()", ¶ms_list); + EXPECT_EQ(2, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_TRUE(params_list[0].second.empty()); + EXPECT_EQ("bar", params_list[1].first); + EXPECT_TRUE(params_list[1].second.empty()); + + ParseTransformParameters("foo(bob_something=sue)", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_EQ(1, params_list[0].second.count("bob_something")); + EXPECT_EQ(1, params_list[0].second["bob_something"].size()); + EXPECT_EQ("sue", params_list[0].second["bob_something"][0]); + + ParseTransformParameters("bar(a=1, b=2, a=3)", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("bar", params_list[0].first); + EXPECT_EQ(1, params_list[0].second.count("a")); + EXPECT_EQ(2, params_list[0].second["a"].size()); + EXPECT_EQ("1", params_list[0].second["a"][0]); + EXPECT_EQ("3", params_list[0].second["a"][1]); + EXPECT_EQ(1, params_list[0].second.count("b")); + EXPECT_EQ(1, params_list[0].second["b"].size()); + EXPECT_EQ("2", params_list[0].second["b"][0]); + + ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("bar", params_list[0].first); + EXPECT_EQ(1, params_list[0].second.count("a")); + EXPECT_EQ(2, params_list[0].second["a"].size()); + EXPECT_EQ("1", params_list[0].second["a"][0]); + EXPECT_EQ("3", params_list[0].second["a"][1]); + EXPECT_EQ(1, params_list[0].second.count("b")); + EXPECT_EQ(1, params_list[0].second["b"].size()); + EXPECT_EQ("1,2,3", params_list[0].second["b"][0]); + } + + void TestShouldIgnoreErrors() { + bool ignore_errors; + TF_EXPECT_OK( + ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors)); + EXPECT_TRUE(ignore_errors); + + TF_EXPECT_OK( + ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors)); + EXPECT_FALSE(ignore_errors); + + TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors)); + EXPECT_FALSE(ignore_errors); + + EXPECT_FALSE( + ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok()); + } +}; + +TEST_F(TransformGraphTest, TestConstantFolding) { TestConstantFolding(); } + +TEST_F(TransformGraphTest, TestTransformRegistration) { + TestTransformRegistration(); +} + +TEST_F(TransformGraphTest, TestParseTransformParameters) { + TestParseTransformParameters(); +} + +TEST_F(TransformGraphTest, TestShouldIgnoreErrors) { TestShouldIgnoreErrors(); } + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 72664eee9ba..0a0b0f01a5f 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -15,12 +15,50 @@ limitations under the License. #include "tensorflow/tools/graph_transforms/transform_utils.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/public/session.h" namespace tensorflow { namespace graph_transforms { +namespace { +inline bool IsMerge(const NodeDef& node_def) { + return node_def.op() == "Merge" || node_def.op() == "RefMerge"; +} + +void RecordMatchedNodes(const NodeMatch& match, + std::set* matched_nodes) { + matched_nodes->insert(match.node.name()); + for (const NodeMatch& input_match : match.inputs) { + RecordMatchedNodes(input_match, matched_nodes); + } +} + +inline uint64 Hash64String(const string& input) { + return Hash64(input.data(), input.size()); +} +} // namespace + +void MatchedNodesAsArray(const NodeMatch& match, std::vector* result) { + std::set found_nodes; + std::vector current_matches = {match}; + while (!current_matches.empty()) { + std::vector next_matches; + for (const NodeMatch& current_match : current_matches) { + if (found_nodes.count(current_match.node.name())) { + continue; + } + found_nodes.insert(current_match.node.name()); + result->push_back(current_match.node); + for (const NodeMatch& input_match : current_match.inputs) { + next_matches.push_back(input_match); + } + } + current_matches = next_matches; + } +} + void MapNamesToNodes(const GraphDef& graph_def, std::map* result) { for (const NodeDef& node : graph_def.node()) { @@ -28,7 +66,19 @@ void MapNamesToNodes(const GraphDef& graph_def, } } -void NodeNamePartsFromInput(string input_name, string* prefix, +void MapNodesToOutputs(const GraphDef& graph_def, + std::map>* result) { + std::map node_map; + MapNamesToNodes(graph_def, &node_map); + for (const NodeDef& node : graph_def.node()) { + for (const string& input : node.input()) { + string input_node_name = NodeNameFromInput(input); + (*result)[input_node_name].push_back(&node); + } + } +} + +void NodeNamePartsFromInput(const string& input_name, string* prefix, string* node_name, string* suffix) { std::vector input_parts = str_util::Split(input_name, ':'); if (input_parts.size() < 2) { @@ -45,7 +95,7 @@ void NodeNamePartsFromInput(string input_name, string* prefix, *node_name = node_name_piece.ToString(); } -string NodeNameFromInput(string input_name) { +string NodeNameFromInput(const string& input_name) { string prefix; string node_name; string suffix; @@ -53,6 +103,57 @@ string NodeNameFromInput(string input_name) { return node_name; } +string CanonicalInputName(const string& input_name) { + string prefix; + string node_name; + string suffix; + NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix); + if (suffix == "") { + suffix = ":0"; + } + return prefix + node_name + suffix; +} + +uint64 HashNodeDef(const NodeDef& node) { + uint64 hash = Hash64String(node.op()); + hash = Hash64Combine(hash, Hash64String(node.name())); + for (const string& input : node.input()) { + hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input))); + } + hash = Hash64Combine(hash, Hash64String(node.device())); + std::vector attr_names; + attr_names.reserve(node.attr().size()); + for (const auto& attr : node.attr()) { + attr_names.push_back(attr.first); + } + std::sort(attr_names.begin(), attr_names.end()); + string attr_serialized; + for (const string& attr_name : attr_names) { + auto attr = node.attr().at(attr_name); + attr.SerializeToString(&attr_serialized); + hash = Hash64Combine(hash, Hash64String(attr_serialized)); + } + return hash; +} + +void AddNodeInput(const string& input_name, NodeDef* node) { + *(node->mutable_input()->Add()) = input_name; +} + +void CopyNodeAttr(const NodeDef& source, const string& source_key, + const string& dest_key, NodeDef* dest) { + CHECK_NE(0, source.attr().count(source_key)) + << "No key '" << source_key << "' found in " << source.DebugString(); + (*(dest->mutable_attr()))[dest_key].CopyFrom(source.attr().at(source_key)); +} + +Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) { + TensorProto tensor_proto = node.attr().at(key).tensor(); + Tensor tensor; + CHECK(tensor.FromProto(tensor_proto)); + return tensor; +} + void FilterGraphDef(const GraphDef& input_graph_def, std::function selector, GraphDef* output_graph_def) { @@ -77,5 +178,425 @@ void RemoveAttributes(const GraphDef& input_graph_def, } } +Status SortByExecutionOrder(const GraphDef& input_graph_def, + GraphDef* output_graph_def) { + const int num_nodes = input_graph_def.node_size(); + std::vector ready; + std::vector pending_count; + pending_count.reserve(num_nodes); + std::vector> outputs(num_nodes); + + std::map name_index; + for (int i = 0; i < input_graph_def.node_size(); ++i) { + const NodeDef& node(input_graph_def.node(i)); + name_index[node.name()] = i; + } + + // Parse the inputs for each node. + for (int n = 0; n < num_nodes; ++n) { + const NodeDef& node_def(input_graph_def.node(n)); + if (IsMerge(node_def)) { + // for merge only wait for one non-control input. + int32 num_control_edges = 0; + for (int i = 0; i < node_def.input_size(); ++i) { + StringPiece input_name(node_def.input(i)); + if (input_name.starts_with("^")) { + num_control_edges++; + } + } + pending_count.push_back(num_control_edges + 1); + } else { + pending_count.push_back(node_def.input_size()); + } + if (node_def.input_size() == 0) { + ready.push_back(n); + continue; + } + for (int i = 0; i < node_def.input_size(); ++i) { + const string& input_name = node_def.input(i); + const string& input_node_name = NodeNameFromInput(input_name); + if (!name_index.count(input_node_name)) { + return errors::InvalidArgument("Node '", node_def.name(), + "': Unknown input node '", + node_def.input(i), "'"); + } + outputs[name_index[input_node_name]].push_back(n); + } + } + + int processed = 0; + output_graph_def->Clear(); + // Process the NodeDefs in topological order. + // Code above sets this up by filling in ready_ with nodes that have no + // inputs, pending_counts_ with the number of inputs for each node and + // outputs_ with the outputs of each node. + while (!ready.empty()) { + int o = ready.back(); + ready.pop_back(); + ++processed; + const NodeDef& node_def(input_graph_def.node(o)); + output_graph_def->mutable_node()->Add()->CopyFrom(node_def); + + // Update pending_count for outputs. + for (size_t i = 0; i < outputs[o].size(); ++i) { + const int output = outputs[o][i]; + pending_count[output]--; + if (pending_count[output] == 0) { + ready.push_back(output); + } + } + } + + if (processed < input_graph_def.node_size()) { + return errors::InvalidArgument(input_graph_def.node_size() - processed, + " nodes in a cycle"); + } + return Status::OK(); +} + +string OpTypePattern::DebugString() const { + string result = "{" + op + ", {"; + for (const OpTypePattern& input : inputs) { + result += input.DebugString() + ","; + } + result += "}}"; + return result; +} + +string NodeMatch::DebugString() const { + string result = "{"; + result += node.DebugString(); + result += ", {"; + for (const NodeMatch& input : inputs) { + result += input.DebugString() + ","; + } + result += "}}"; + return result; +} + +GraphMatcher::GraphMatcher(const GraphDef& graph_def) { + SortByExecutionOrder(graph_def, &graph_def_); + MapNamesToNodes(graph_def_, &node_map_); +} + +Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern, + std::vector* matches) { + std::set matched_nodes; + for (const NodeDef& node : graph_def_.node()) { + // Skip any nodes that are already part of a match. + if (matched_nodes.count(node.name())) { + continue; + } + NodeMatch match; + if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) { + RecordMatchedNodes(match, &matched_nodes); + matches->push_back(match); + } + } + return Status::OK(); +} + +bool GraphMatcher::DoesOpTypeMatch( + const NodeDef& node, const OpTypePattern& pattern, + const std::set& previously_matched_nodes, NodeMatch* match) { + VLOG(1) << "Looking at node " << node.DebugString(); + VLOG(1) << "pattern=" << pattern.DebugString(); + VLOG(1) << "match=" << match->DebugString(); + if (previously_matched_nodes.count(node.name())) { + VLOG(1) << "node " << node.name() << " has been previously matched"; + return false; + } + bool pattern_matched = false; + if (pattern.op == "*") { + pattern_matched = true; + } else { + std::vector pattern_ops = str_util::Split(pattern.op, '|'); + for (const string& pattern_op : pattern_ops) { + if (node.op() == pattern_op) { + pattern_matched = true; + } + } + } + if (!pattern_matched) { + VLOG(1) << "node.op() != pattern.op()"; + return false; + } + match->node = node; + // Ignore any control inputs for pattern-matching purposes + std::vector non_control_inputs; + for (const string& input : node.input()) { + if (!input.empty() && (input[0] != '^')) { + non_control_inputs.push_back(input); + } + } + if (pattern.inputs.empty()) { + // If there are no inputs, assume that's the end of the pattern. + return true; + } + if (non_control_inputs.size() != pattern.inputs.size()) { + VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()"; + return false; + } + for (int i = 0; i < pattern.inputs.size(); ++i) { + const string& input_node_name = NodeNameFromInput(non_control_inputs[i]); + const NodeDef& input_node = *(node_map_[input_node_name]); + const OpTypePattern& input_pattern = pattern.inputs[i]; + match->inputs.push_back(NodeMatch()); + NodeMatch* input_match = &(match->inputs.back()); + if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes, + input_match)) { + return false; + } + } + return true; +} + +Status ReplaceMatchingOpTypes( + const GraphDef& input_graph_def, const OpTypePattern& pattern, + const std::function&, + const std::set&, std::vector*)>& + node_generator, + const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) { + // Start off by retrieving all the matching subgraphs. + GraphMatcher matcher(input_graph_def); + std::vector matches; + matcher.GetOpTypeMatches(pattern, &matches); + + // Do some housekeeping so we can easily look up the resulting matches given + // a node name. + std::set matched_nodes; + std::map matches_by_head_name; + for (const NodeMatch& match : matches) { + matches_by_head_name[match.node.name()] = &match; + RecordMatchedNodes(match, &matched_nodes); + } + std::map> outputs_map; + MapNodesToOutputs(input_graph_def, &outputs_map); + + // Go through all the nodes in the input graph, see if they are part of a + // match or if they can be left untouched. + output_graph_def->Clear(); + for (const NodeDef& input_node : input_graph_def.node()) { + if (matches_by_head_name.count(input_node.name())) { + // This node is the beginning of a match, so call the replacement function + // after setting up some information it will need. + const NodeMatch* match = matches_by_head_name[input_node.name()]; + std::vector matched_nodes_array; + MatchedNodesAsArray(*match, &matched_nodes_array); + // This tells us whether a node is part of the current match. + std::set matched_nodes_lookup; + for (const NodeDef& matched_node : matched_nodes_array) { + matched_nodes_lookup.insert(matched_node.name()); + } + // These are helper arrays that the replacement function can use to tell + // whether it can safely remove an internal node (because nothing outside + // of the match uses it) or whether external nodes depend on it. + std::set input_nodes; + std::set output_nodes; + for (const NodeDef& matched_node : matched_nodes_array) { + // Look through all of this node's inputs, and if any of them come from + // outside the match, then this should be noted as one of the external + // inputs of the subgraph. + for (const string& input_name : matched_node.input()) { + string input_node_name = NodeNameFromInput(input_name); + if (!matched_nodes_lookup.count(input_node_name)) { + input_nodes.insert(matched_node.name()); + } + } + // Do a reverse input lookup, to see which other nodes use the current + // one as an input. If any of those nodes are outside the match + // subgraph, then the current node is marked as an output node that + // shouldn't be removed. + if (outputs_map.count(matched_node.name())) { + for (const NodeDef* dependent_node : + outputs_map[matched_node.name()]) { + if (!matched_nodes_lookup.count(dependent_node->name())) { + output_nodes.insert(matched_node.name()); + } + } + } + } + // Call the generator function and add all the returned nodes to the + // graph. + std::vector new_nodes; + TF_RETURN_IF_ERROR( + node_generator(*match, input_nodes, output_nodes, &new_nodes)); + std::set new_node_names; + for (const NodeDef& new_node : new_nodes) { + new_node_names.insert(new_node.name()); + } + // Check to make sure the generator function preserved all of the nodes + // that are used elsewhere in the graph, and add them back in if not. + bool abort_replacement = false; + if (!options.allow_inconsistencies) { + for (const string& expected_output : output_nodes) { + if (!new_node_names.count(expected_output)) { + LOG(WARNING) << "Expected " << expected_output + << " to be preserved."; + abort_replacement = true; + } + } + } + if (abort_replacement) { + LOG(WARNING) << "Generator function didn't preserve needed nodes, " + << "copying old replacements back in instead."; + std::vector old_nodes; + MatchedNodesAsArray(*match, &old_nodes); + for (const NodeDef& old_node : old_nodes) { + NodeDef* added_node = output_graph_def->mutable_node()->Add(); + added_node->CopyFrom(old_node); + } + } else { + for (const NodeDef& new_node : new_nodes) { + NodeDef* added_node = output_graph_def->mutable_node()->Add(); + added_node->CopyFrom(new_node); + } + } + } else if (!matched_nodes.count(input_node.name())) { + // This node isn't part of any match, so just copy it over. + NodeDef* added_node = output_graph_def->mutable_node()->Add(); + added_node->CopyFrom(input_node); + } else { + // Do nothing, because this is an internal part of a matching subgraph, + // and so will have been replaced by a new replacement subgraph. + } + } + + return Status::OK(); +} + +Status RenameNodeInputs(const GraphDef& input_graph_def, + const std::map& inputs_to_rename, + GraphDef* output_graph_def) { + std::map>> + canonical_inputs_to_rename; + for (const auto& input_to_rename : inputs_to_rename) { + canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)] + .push_back({input_to_rename.first, input_to_rename.second}); + } + + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + new_node->mutable_input()->Clear(); + for (const string& input_name : node.input()) { + std::set already_visited; + string new_input_name = input_name; + while ( + canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) { + string input_node_name = NodeNameFromInput(new_input_name); + if (already_visited.count(input_node_name)) { + return errors::InvalidArgument( + "RenameNodeInputs argument contains a cycle for ", + input_node_name); + } + already_visited.insert(input_node_name); + bool any_match_found = false; + for (const std::pair& input_to_rename : + canonical_inputs_to_rename.at(input_node_name)) { + const string& source_name = input_to_rename.first; + const string& dest_name = input_to_rename.second; + bool is_match; + string match_name; + if (StringPiece(source_name).ends_with(":*")) { + is_match = true; + string prefix; + string unused_node_name; + string suffix; + NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name, + &suffix); + match_name = prefix + dest_name + suffix; + } else { + is_match = (CanonicalInputName(source_name) == + CanonicalInputName(new_input_name)); + match_name = dest_name; + } + if (is_match) { + new_input_name = match_name; + any_match_found = true; + } + } + if (!any_match_found) { + break; + } + } + *(new_node->mutable_input()->Add()) = new_input_name; + } + } + return Status::OK(); +} + +void CopyOriginalMatch(const NodeMatch& match, + std::vector* new_nodes) { + std::vector old_nodes; + MatchedNodesAsArray(match, &old_nodes); + for (const NodeDef& old_node : old_nodes) { + new_nodes->push_back(old_node); + } +} + +TransformRegistry* GetTransformRegistry() { + static TransformRegistry transform_registry; + return &transform_registry; +} + +void FindInvalidInputs(const GraphDef& graph_def, + std::vector>* invalid_inputs) { + std::map node_map; + MapNamesToNodes(graph_def, &node_map); + + for (const NodeDef& node : graph_def.node()) { + for (const string& input : node.input()) { + string input_node = NodeNameFromInput(input); + if (!node_map.count(input_node)) { + invalid_inputs->push_back({node.name(), input_node}); + } + } + } +} + +Status IsGraphValid(const GraphDef& graph_def) { + std::vector> invalid_inputs; + FindInvalidInputs(graph_def, &invalid_inputs); + if (!invalid_inputs.empty()) { + std::map node_map; + MapNamesToNodes(graph_def, &node_map); + for (const std::pair& invalid_input : invalid_inputs) { + LOG(ERROR) << "Invalid input " << invalid_input.second << " for node " + << invalid_input.first << " - " + << node_map[invalid_input.first]->DebugString(); + } + return errors::Internal( + "Invalid graph with inputs referring to nonexistent nodes"); + } + return Status::OK(); +} + +int CountParameters(const TransformFuncContext& context, const string& name) { + if (context.params.count(name)) { + return context.params.at(name).size(); + } else { + return 0; + } +} + +Status GetExactlyOneParameter(const TransformFuncContext& context, + const string& name, const string& default_value, + string* result) { + const int params_count = CountParameters(context, name); + if (params_count == 0) { + *result = default_value; + return Status::OK(); + } else if (params_count == 1) { + *result = context.params.at(name).at(0); + return Status::OK(); + } else { + return errors::InvalidArgument("Expected a single '", name, + "' parameter, but found ", params_count, + " occurrences"); + } +} + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 7fb885b1ac2..7672011d6cf 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -16,7 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ +#include +#include + #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -26,18 +30,73 @@ namespace graph_transforms { void MapNamesToNodes(const GraphDef& graph_def, std::map* result); +// For every node in the graph create a list of the nodes that use it as an +// input. +void MapNodesToOutputs(const GraphDef& graph_def, + std::map>* result); + // NodeDef input strings can contain other information besides the name of an // input node. These include: // - Optional '^' prefix, indicating this is a control edge. // - The required name of the input node. -// - Option ':' suffix, showing which output of the node to use. +// - Optional ':' suffix, showing which output of the node to use. // This function takes a raw string, and breaks it into those component parts. -void NodeNamePartsFromInput(string input_name, string* prefix, +// The rules for inputs in function libraries are a bit more complex, and +// aren't handled by this routine. +void NodeNamePartsFromInput(const string& input_name, string* prefix, string* node_name, string* suffix); +// Adds a ':0' port to any inputs with no suffix, to make comparisons easier. +string CanonicalInputName(const string& input_name); + // Convenience function to strip the optional prefix and suffix components from // a string pulled from a NodeDef input, and return the plain node name. -string NodeNameFromInput(string input_name); +string NodeNameFromInput(const string& input_name); + +// Returns a stable hash for the contents of the NodeDef, so that equivalent +// nodes should have equal hashes. +uint64 HashNodeDef(const NodeDef& node); + +// Adds the given node name to the end of the node's inputs. +void AddNodeInput(const string& input_name, NodeDef* node); + +// Copies an attribute from one NodeDef to another. +void CopyNodeAttr(const NodeDef& source, const string& source_key, + const string& dest_key, NodeDef* dest); + +// Inserts a value into a NodeDef's map of attributes. +// This is a bit different than AddNodeAttr in node_def_util.h because it +// overwrites any existing attributes with the same key. +template +inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) { + AttrValue attr_value; + SetAttrValue(value, &attr_value); + auto* attr_map = node->mutable_attr(); + (*attr_map)[key] = attr_value; +} + +template +inline void SetNodeTensorAttr(const string& key, const Tensor& tensor, + NodeDef* node) { + TensorProto tensor_proto; + tensor.AsProtoTensorContent(&tensor_proto); + SetNodeAttr(key, tensor_proto, node); +} + +// Inserts a Tensor into the specified attribute of a NodeDef. +template +inline void SetNodeTensorAttr(const string& key, const TensorShape& shape, + const std::vector& values, NodeDef* node) { + const DataType dtype = DataTypeToEnum::v(); + CHECK_EQ(shape.num_elements(), values.size()); + Tensor tensor(dtype, shape); + T* dest_data = tensor.flat().data(); + std::copy_n(values.data(), values.size(), dest_data); + SetNodeTensorAttr(key, tensor, node); +} + +// Retrieves a tensor value from a NodeDef attribute. +Tensor GetNodeTensorAttr(const NodeDef& node, const string& key); // Creates a copy of the input GraphDef, but only containing the nodes where the // supplied selector function returned true. @@ -51,6 +110,144 @@ void RemoveAttributes(const GraphDef& input_graph_def, const std::vector& attributes, GraphDef* output_graph_def); +// For a lot of replacement and matching operations it's useful to have the +// nodes processed in a controlled order, so this does a topological sort to +// ensure that nodes always appear in the GraphDef.node list after their inputs. +Status SortByExecutionOrder(const GraphDef& input_graph_def, + GraphDef* output_graph_def); + +// Finds inputs that refer to nodes that are not in the graph. +void FindInvalidInputs(const GraphDef& graph_def, + std::vector>* invalid_inputs); + +// Returns a descriptive error status if there are problems spotted with the +// graph. +Status IsGraphValid(const GraphDef& graph_def); + +// This is used to spot particular subgraphs in a larger model. To use it, +// create a pattern like: +// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); +// This defines a subgraph where a Conv2D has a ResizeBilinear input, which +// pulls from a MirrorPad op. +// Regular expressions aren't supported for the op names, but you can use "*" to +// match any op. You can also use | as a separator to match multiple op names, +// like "Reshape|Concat|Conv2D". +struct OpTypePattern { + string op; + std::vector inputs; + string DebugString() const; +}; + +// Returns a sub-graph of nodes that match a pattern. +struct NodeMatch { + NodeMatch() : node() {} + NodeDef node; + std::vector inputs; + string DebugString() const; +}; + +// Utility class to spot subgraphs matching particular patterns. +class GraphMatcher { + public: + GraphMatcher(const GraphDef& graph_def); + + // Sorts the input nodes into execution order, and then skips any previously + // matches so that no node appears in more than one match. The NodeDef + // pointers contained in the results are owned by the GraphMatcher object, and + // so will be invalid after its lifetime. + Status GetOpTypeMatches(const OpTypePattern& pattern, + std::vector* matches); + + private: + bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern, + const std::set& previously_matched_nodes, + NodeMatch* match); + + GraphDef graph_def_; + std::map node_map_; +}; + +struct ReplaceMatchingOpTypesOptions { + // Whether to raise an error if the graph is left with dangling inputs. If you + // enable this option, you must fix inconsistencies in a later pass. + bool allow_inconsistencies; +}; + +// Replaces all of the matching sub-graphs with new ops. This calls into the +// given function, and expects to receive a set of new nodes to replace each +// matched sub-graph. It has some logic to protect the integrity of the +// resulting graph, for example making sure that nodes needed by other nodes +// outside the sub-graph aren't removed. These are passed in as the set of +// outputs, and nodes with the same names must be added to the new nodes +// produced by the replacement function. Many of these checks can be disabled +// by setting allow_inconsistencies to true in the options, but then it's the +// caller's responsibility to patch up any problems before passing on the graph +// to others. There's more comprehensive usage documentation in the README. +Status ReplaceMatchingOpTypes( + const GraphDef& input_graph_def, const OpTypePattern& pattern, + const std::function&, + const std::set&, std::vector*)>& + node_generator, + const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def); + +// Returns a list of the unique nodes found in this match. +void MatchedNodesAsArray(const NodeMatch& match, std::vector* result); + +// Changes all input references to a particular node name. +Status RenameNodeInputs(const GraphDef& input_graph_def, + const std::map& inputs_to_rename, + GraphDef* output_graph_def); + +// Utility function that copies all the nodes found in a match into the +// new_nodes list. This is useful in replacement functions when you decide to +// leave the original matched subgraph untouched and make no changes. +void CopyOriginalMatch(const NodeMatch& match, std::vector* new_nodes); + +// Holds information that's needed for transform functions. +typedef std::map> TransformFuncParameters; +struct TransformFuncContext { + std::vector input_names; + std::vector output_names; + TransformFuncParameters params; +}; + +// Returns how many occurrences of the given parameter are present. +int CountParameters(const TransformFuncContext& context, const string& name); + +// Gets a simple occurrence of a parameter, using a default if it isn't present. +Status GetExactlyOneParameter(const TransformFuncContext& context, + const string& name, const string& default_value, + string* result); + +// This is the function API for all graph transformations, taking an input +// GraphDef and other arguments, and returning a transformed GraphDef. +typedef std::function + TransformFunc; + +// To add a new graph transform function, call the macro: +// REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants); +// Under the hood this adds the function to the list of known transforms, so you +// just need to link in the .cc file with your registration call to have access +// to it through the command line tool. +// The rest of the machinery below is to enable that automagical registration. +typedef std::map TransformRegistry; +TransformRegistry* GetTransformRegistry(); +class TransformRegistrar { + public: + TransformRegistrar(const string& name, TransformFunc transform_func) { + TransformRegistry* transform_registry = GetTransformRegistry(); + (*transform_registry)[name] = transform_func; + } +}; +#define REGISTER_GRAPH_TRANSFORM(name, func) \ + REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func) +#define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \ + REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func) +#define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func) \ + static tensorflow::graph_transforms::TransformRegistrar \ + registrar__body__##ctr##__object(name, func); + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index 1c1f4d97ed6..3e9f661f672 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -52,8 +52,8 @@ class TransformUtilsTest : public ::testing::Test { GraphDef graph_def; TF_ASSERT_OK(root.ToGraphDef(&graph_def)); std::map node_map; - MapNamesToNodes(graph_def, &node_map); + EXPECT_EQ(1, node_map.count("a")); EXPECT_EQ(1, node_map.count("b")); EXPECT_EQ(1, node_map.count("add")); @@ -62,6 +62,52 @@ class TransformUtilsTest : public ::testing::Test { EXPECT_EQ(0, node_map.count("no_such_node")); } + void TestMapNodesToOutputs() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add"), a_const, b_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + std::map> outputs_map; + MapNodesToOutputs(graph_def, &outputs_map); + + EXPECT_EQ(1, outputs_map.count("a")); + EXPECT_EQ(1, outputs_map["a"].size()); + EXPECT_EQ("add", outputs_map["a"][0]->name()); + + EXPECT_EQ(1, outputs_map.count("b")); + EXPECT_EQ(1, outputs_map["b"].size()); + EXPECT_EQ("add", outputs_map["b"][0]->name()); + + EXPECT_EQ(1, outputs_map.count("add")); + EXPECT_EQ(1, outputs_map["add"].size()); + EXPECT_EQ("output", outputs_map["add"][0]->name()); + + EXPECT_EQ(1, outputs_map.count("placeholder")); + EXPECT_EQ(1, outputs_map["placeholder"].size()); + EXPECT_EQ("output", outputs_map["placeholder"][0]->name()); + + EXPECT_EQ(0, outputs_map.count("output")); + EXPECT_EQ(0, outputs_map.count("no_such_node")); + } + void TestNodeNamePartsFromInput() { string prefix; string node_name; @@ -101,6 +147,75 @@ class TransformUtilsTest : public ::testing::Test { EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42")); } + void TestCanonicalInputName() { + EXPECT_EQ("node_name:0", CanonicalInputName("node_name")); + EXPECT_EQ("node_name:0", CanonicalInputName("node_name:0")); + EXPECT_EQ("^node_name:0", CanonicalInputName("^node_name")); + EXPECT_EQ("^node_name:42", CanonicalInputName("^node_name:42")); + } + + void TestAddNodeInput() { + NodeDef node; + AddNodeInput("foo", &node); + EXPECT_EQ("foo", node.input(0)); + } + + void TestCopyNodeAttr() { + NodeDef node; + auto mutable_attr = node.mutable_attr(); + (*mutable_attr)["foo"].set_i(3); + + NodeDef copied_node; + CopyNodeAttr(node, "foo", "bar", &copied_node); + EXPECT_EQ(3, copied_node.attr().at("bar").i()); + } + + void TestSetNodeAttr() { + NodeDef node; + int32 value_i = 32; + SetNodeAttr("foo", value_i, &node); + EXPECT_EQ(32, node.attr().at("foo").i()); + string value_s = "some_value"; + SetNodeAttr("bar", value_s, &node); + EXPECT_EQ("some_value", node.attr().at("bar").s()); + } + + void TestSetNodeTensorAttr() { + NodeDef node; + SetNodeTensorAttr("foo", {3, 1}, {1, 2, 3}, &node); + TensorProto tensor_proto = node.attr().at("foo").tensor(); + Tensor tensor; + CHECK(tensor.FromProto(tensor_proto)); + EXPECT_EQ(DT_INT32, tensor.dtype()); + EXPECT_EQ(3, tensor.shape().dim_size(0)); + EXPECT_EQ(1, tensor.shape().dim_size(1)); + EXPECT_EQ(1, tensor.flat()(0)); + EXPECT_EQ(2, tensor.flat()(1)); + EXPECT_EQ(3, tensor.flat()(2)); + } + + void TestSetNodeTensorAttrWithTensor() { + NodeDef node; + Tensor input_tensor(DT_INT32, {4, 5}); + test::FillIota(&input_tensor, 1); + SetNodeTensorAttr("foo", input_tensor, &node); + TensorProto tensor_proto = node.attr().at("foo").tensor(); + Tensor tensor; + CHECK(tensor.FromProto(tensor_proto)); + test::ExpectTensorEqual(input_tensor, tensor); + } + + void TestGetNodeTensorAttr() { + NodeDef node; + Tensor input_tensor(DT_INT32, {4, 5}); + test::FillIota(&input_tensor, 1); + TensorProto tensor_proto; + input_tensor.AsProtoTensorContent(&tensor_proto); + SetNodeAttr("foo", tensor_proto, &node); + Tensor result = GetNodeTensorAttr(node, "foo"); + test::ExpectTensorEqual(input_tensor, result); + } + void TestFilterGraphDef() { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) @@ -160,19 +275,679 @@ class TransformUtilsTest : public ::testing::Test { EXPECT_EQ(nullptr, tensorflow::AttrSlice(*removed_placeholder).Find("dtype")); } + + void TestGetOpTypeMatches() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add"), a_const, b_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphMatcher matcher(graph_def); + + std::vector const_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches({"Const"}, &const_matches)); + EXPECT_EQ(2, const_matches.size()); + for (const NodeMatch& match : const_matches) { + EXPECT_EQ("Const", match.node.op()); + EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name())) + << "match.node.name()=" << match.node.name(); + } + + std::vector add_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add"}, &add_matches)); + EXPECT_EQ(1, add_matches.size()); + EXPECT_EQ("Add", add_matches[0].node.op()); + EXPECT_EQ("add", add_matches[0].node.name()); + + std::vector add_child_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}}, + &add_child_matches)); + EXPECT_EQ(1, add_child_matches.size()); + EXPECT_EQ("Add", add_child_matches[0].node.op()); + EXPECT_EQ("add", add_child_matches[0].node.name()); + EXPECT_EQ(2, add_child_matches[0].inputs.size()); + for (const NodeMatch& match : add_child_matches[0].inputs) { + EXPECT_EQ("Const", match.node.op()); + EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name())) + << "match.node.name()=" << match.node.name(); + } + + std::vector no_such_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches({"NoSuch"}, &no_such_matches)); + EXPECT_EQ(0, no_such_matches.size()); + + std::vector all_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches( + {"Mul", {{"Add", {{"Const"}, {"Const"}}}, {"Placeholder"}}}, + &all_matches)); + EXPECT_EQ(1, all_matches.size()); + EXPECT_EQ("Mul", all_matches[0].node.op()); + EXPECT_EQ("output", all_matches[0].node.name()); + EXPECT_EQ(2, all_matches[0].inputs.size()); + EXPECT_EQ("Add", all_matches[0].inputs[0].node.op()); + EXPECT_EQ("add", all_matches[0].inputs[0].node.name()); + EXPECT_EQ(2, all_matches[0].inputs[0].inputs.size()); + EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[0].node.op()); + EXPECT_EQ("a", all_matches[0].inputs[0].inputs[0].node.name()); + EXPECT_EQ(0, all_matches[0].inputs[0].inputs[0].inputs.size()); + EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[1].node.op()); + EXPECT_EQ("b", all_matches[0].inputs[0].inputs[1].node.name()); + EXPECT_EQ(0, all_matches[0].inputs[0].inputs[1].inputs.size()); + EXPECT_EQ("Placeholder", all_matches[0].inputs[1].node.op()); + EXPECT_EQ("placeholder", all_matches[0].inputs[1].node.name()); + EXPECT_EQ(0, all_matches[0].inputs[1].inputs.size()); + + std::vector wildcard_matches; + TF_ASSERT_OK( + matcher.GetOpTypeMatches({"*", {{"*"}, {"*"}}}, &wildcard_matches)); + EXPECT_EQ(1, wildcard_matches.size()); + EXPECT_EQ("Add", wildcard_matches[0].node.op()); + EXPECT_EQ("Const", wildcard_matches[0].inputs[0].node.op()); + EXPECT_EQ("a", wildcard_matches[0].inputs[0].node.name()); + EXPECT_EQ("Const", wildcard_matches[0].inputs[1].node.op()); + EXPECT_EQ("b", wildcard_matches[0].inputs[1].node.name()); + + std::vector or_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add|Mul"}, &or_matches)); + EXPECT_EQ(2, or_matches.size()); + EXPECT_EQ("Add", or_matches[0].node.op()); + EXPECT_EQ("add", or_matches[0].node.name()); + EXPECT_EQ("Mul", or_matches[1].node.op()); + EXPECT_EQ("output", or_matches[1].node.name()); + } + + void TestGetOpTypeMatchesDAG() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Output add = Add(root.WithOpName("add"), a_const, a_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphMatcher matcher(graph_def); + + std::vector add_matches; + TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}}, + &add_matches)); + EXPECT_EQ(1, add_matches.size()); + EXPECT_EQ("Add", add_matches[0].node.op()); + EXPECT_EQ("add", add_matches[0].node.name()); + EXPECT_EQ("Const", add_matches[0].inputs[0].node.op()); + EXPECT_EQ("a", add_matches[0].inputs[0].node.name()); + EXPECT_EQ("Const", add_matches[0].inputs[1].node.op()); + EXPECT_EQ("a", add_matches[0].inputs[1].node.name()); + } + + void TestReplaceMatchingOpTypes() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add"), a_const, b_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef replaced_graph_def; + TF_ASSERT_OK(ReplaceMatchingOpTypes( + graph_def, {"*"}, + [](const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + NodeDef original_copy; + original_copy.CopyFrom(match.node); + const string original_name = match.node.name(); + original_copy.set_name(original_name + "_before_identity"); + new_nodes->push_back(original_copy); + + NodeDef identity_node; + identity_node.set_op("Identity"); + identity_node.set_name(original_name); + *(identity_node.mutable_input()->Add()) = original_copy.name(); + new_nodes->push_back(identity_node); + + return Status::OK(); + }, + {}, &replaced_graph_def)); + + EXPECT_EQ(10, replaced_graph_def.node_size()); + for (const NodeDef& node : replaced_graph_def.node()) { + if (node.name() == "output") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("output_before_identity", node.input(0)); + } else if (node.name() == "output_before_identity") { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ("add", node.input(0)); + EXPECT_EQ("placeholder", node.input(1)); + } else if (node.name() == "placeholder") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("placeholder_before_identity", node.input(0)); + } else if (node.name() == "placeholder_before_identity") { + EXPECT_EQ("Placeholder", node.op()); + } else if (node.name() == "add") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("add_before_identity", node.input(0)); + } else if (node.name() == "add_before_identity") { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ("a", node.input(0)); + EXPECT_EQ("b", node.input(1)); + } else if (node.name() == "a") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("a_before_identity", node.input(0)); + } else if (node.name() == "a_before_identity") { + EXPECT_EQ("Const", node.op()); + } else if (node.name() == "b") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("b_before_identity", node.input(0)); + } else if (node.name() == "b_before_identity") { + EXPECT_EQ("Const", node.op()); + } else { + EXPECT_EQ(true, false) << "Unexpected node name found: " << node.name(); + } + } + } + + void TestMatchedNodesAsArray() { + NodeMatch fourth; + fourth.node.set_name("fourth"); + + NodeMatch second; + second.node.set_name("second"); + second.inputs.push_back(fourth); + + NodeMatch third; + third.node.set_name("third"); + third.inputs.push_back(fourth); + + NodeMatch first; + first.node.set_name("first"); + first.inputs.push_back(second); + first.inputs.push_back(third); + + std::vector result; + MatchedNodesAsArray(first, &result); + + EXPECT_EQ(4, result.size()); + EXPECT_EQ("first", result[0].name()); + EXPECT_EQ("second", result[1].name()); + EXPECT_EQ("third", result[2].name()); + EXPECT_EQ("fourth", result[3].name()); + } + + void TestRenameNodeInputs() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add"), a_const, a_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef renamed_graph_def; + TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def)); + + std::map node_map; + MapNamesToNodes(renamed_graph_def, &node_map); + EXPECT_EQ("b", node_map.at("add")->input(0)); + EXPECT_EQ("b", node_map.at("add")->input(1)); + } + + void TestRenameNodeInputsWithRedirects() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Tensor c_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&c_data, 1.0f); + Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data)); + + Output add = Add(root.WithOpName("add"), a_const, b_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef renamed_graph_def; + TF_ASSERT_OK(RenameNodeInputs( + graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}}, + &renamed_graph_def)); + + std::map node_map; + MapNamesToNodes(renamed_graph_def, &node_map); + EXPECT_EQ("c", node_map.at("add")->input(0)); + EXPECT_EQ("b", node_map.at("add")->input(1)); + } + + void TestRenameNodeInputsWithCycle() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + Tensor c_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&c_data, 1.0f); + Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data)); + + Output add = Add(root.WithOpName("add"), a_const, b_const); + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + Output mul = Mul(root.WithOpName("output"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef renamed_graph_def; + Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}}, + &renamed_graph_def); + EXPECT_FALSE(rename_status.ok()); + } + + void TestRenameNodeInputsWithWildcard() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + + QuantizeV2 quantize_a(root.WithOpName("quantize_a"), a_const, a_const, + a_const, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data)); + + QuantizeV2 quantize_b(root.WithOpName("quantize_b"), b_const, b_const, + b_const, DT_QUINT8, + QuantizeV2::Attrs().Mode("MIN_FIRST")); + + Output add = Add(root.WithOpName("add"), quantize_a.output_min, + quantize_a.output_max); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef renamed_graph_def; + TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}}, + &renamed_graph_def)); + + std::map node_map; + MapNamesToNodes(renamed_graph_def, &node_map); + EXPECT_EQ("quantize_b:1", node_map.at("add")->input(0)); + EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1)); + } + + void TestFindInvalidInputs() { + GraphDef graph_def; + + NodeDef* mul_node = graph_def.mutable_node()->Add(); + mul_node->set_op("Mul"); + mul_node->set_name("mul_node"); + *(mul_node->mutable_input()->Add()) = "add_node1"; + *(mul_node->mutable_input()->Add()) = "add_node2:0"; + *(mul_node->mutable_input()->Add()) = "^const_node1:0"; + + NodeDef* add_node1 = graph_def.mutable_node()->Add(); + add_node1->set_op("Add"); + add_node1->set_name("add_node1"); + *(add_node1->mutable_input()->Add()) = "missing_input1"; + *(add_node1->mutable_input()->Add()) = "const_node1:0"; + *(add_node1->mutable_input()->Add()) = "missing_input2"; + + NodeDef* add_node2 = graph_def.mutable_node()->Add(); + add_node2->set_op("Add"); + add_node2->set_name("add_node2"); + *(add_node2->mutable_input()->Add()) = "missing_input3"; + *(add_node2->mutable_input()->Add()) = "const_node1:0"; + *(add_node2->mutable_input()->Add()) = "^const_node2"; + + NodeDef* const_node1 = graph_def.mutable_node()->Add(); + const_node1->set_op("Const"); + const_node1->set_name("const_node1"); + + NodeDef* const_node2 = graph_def.mutable_node()->Add(); + const_node2->set_op("Const"); + const_node2->set_name("const_node2"); + + std::vector> invalid_inputs; + FindInvalidInputs(graph_def, &invalid_inputs); + EXPECT_EQ(3, invalid_inputs.size()); + for (const std::pair& invalid_input : invalid_inputs) { + EXPECT_TRUE((invalid_input.first == "add_node1") || + (invalid_input.first == "add_node2")); + if (invalid_input.first == "add_node1") { + EXPECT_TRUE((invalid_input.second == "missing_input1") || + (invalid_input.second == "missing_input2")) + << invalid_input.second; + } else if (invalid_input.first == "add_node2") { + EXPECT_EQ("missing_input3", invalid_input.second); + } + } + } + + void TestIsGraphValid() { + GraphDef invalid_graph_def; + + NodeDef* mul_node = invalid_graph_def.mutable_node()->Add(); + mul_node->set_op("Mul"); + mul_node->set_name("mul_node"); + *(mul_node->mutable_input()->Add()) = "add_node1"; + *(mul_node->mutable_input()->Add()) = "add_node2:0"; + *(mul_node->mutable_input()->Add()) = "^const_node1:0"; + + NodeDef* add_node1 = invalid_graph_def.mutable_node()->Add(); + add_node1->set_op("Add"); + add_node1->set_name("add_node1"); + *(add_node1->mutable_input()->Add()) = "missing_input1"; + *(add_node1->mutable_input()->Add()) = "const_node1:0"; + *(add_node1->mutable_input()->Add()) = "missing_input2"; + + NodeDef* add_node2 = invalid_graph_def.mutable_node()->Add(); + add_node2->set_op("Add"); + add_node2->set_name("add_node2"); + *(add_node2->mutable_input()->Add()) = "missing_input3"; + *(add_node2->mutable_input()->Add()) = "const_node1:0"; + *(add_node2->mutable_input()->Add()) = "^const_node2"; + + NodeDef* const_node1 = invalid_graph_def.mutable_node()->Add(); + const_node1->set_op("Const"); + const_node1->set_name("const_node1"); + + NodeDef* const_node2 = invalid_graph_def.mutable_node()->Add(); + const_node2->set_op("Const"); + const_node2->set_name("const_node2"); + + EXPECT_FALSE(IsGraphValid(invalid_graph_def).ok()); + + GraphDef valid_graph_def; + + NodeDef* const_node3 = valid_graph_def.mutable_node()->Add(); + const_node3->set_op("Const"); + const_node3->set_name("const_node2"); + + EXPECT_TRUE(IsGraphValid(valid_graph_def).ok()); + } + + void TestCopyOriginalMatch() { + NodeDef a; + a.set_op("Relu"); + a.set_name("a"); + AddNodeInput("b", &a); + + NodeDef b; + b.set_op("Const"); + b.set_name("b"); + + NodeMatch b_match; + b_match.node = b; + + NodeMatch a_match; + a_match.node = a; + a_match.inputs.push_back(b_match); + + std::vector new_nodes; + CopyOriginalMatch(a_match, &new_nodes); + EXPECT_EQ(2, new_nodes.size()); + EXPECT_EQ("a", new_nodes[0].name()); + EXPECT_EQ("Relu", new_nodes[0].op()); + EXPECT_EQ("b", new_nodes[1].name()); + EXPECT_EQ("Const", new_nodes[1].op()); + } + + void TestHashNodeDef() { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 10; + + auto a_root = tensorflow::Scope::NewRootScope(); + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(a_root.WithOpName("a"), Input::Initializer(a_data)); + GraphDef a_graph_def; + TF_ASSERT_OK(a_root.ToGraphDef(&a_graph_def)); + const NodeDef& a_node_def = a_graph_def.node(0); + + auto b_root = tensorflow::Scope::NewRootScope(); + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&b_data, 1.0f); + Output b_const = Const(b_root.WithOpName("a"), Input::Initializer(b_data)); + GraphDef b_graph_def; + TF_ASSERT_OK(b_root.ToGraphDef(&b_graph_def)); + const NodeDef& b_node_def = b_graph_def.node(0); + + auto c_root = tensorflow::Scope::NewRootScope(); + Tensor c_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&c_data, 2.0f); + Output c_const = Const(c_root.WithOpName("a"), Input::Initializer(c_data)); + GraphDef c_graph_def; + TF_ASSERT_OK(c_root.ToGraphDef(&c_graph_def)); + const NodeDef& c_node_def = c_graph_def.node(0); + + auto d_root = tensorflow::Scope::NewRootScope(); + Tensor d_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&d_data, 1.0f); + Output d_const = Const(d_root.WithOpName("d"), Input::Initializer(d_data)); + GraphDef d_graph_def; + TF_ASSERT_OK(d_root.ToGraphDef(&d_graph_def)); + const NodeDef& d_node_def = d_graph_def.node(0); + + auto e_root = tensorflow::Scope::NewRootScope(); + Tensor e_data(DT_INT32, TensorShape({width})); + test::FillIota(&e_data, 1); + Output e_const = Const(e_root.WithOpName("a"), Input::Initializer(e_data)); + GraphDef e_graph_def; + TF_ASSERT_OK(e_root.ToGraphDef(&e_graph_def)); + const NodeDef& e_node_def = e_graph_def.node(0); + + auto f_root = tensorflow::Scope::NewRootScope(); + Tensor f_data(DT_FLOAT, TensorShape({width - 1})); + test::FillIota(&f_data, 1.0f); + Output f_const = Const(f_root.WithOpName("a"), Input::Initializer(f_data)); + GraphDef f_graph_def; + TF_ASSERT_OK(f_root.ToGraphDef(&f_graph_def)); + const NodeDef& f_node_def = f_graph_def.node(0); + + auto g_root = tensorflow::Scope::NewRootScope(); + Tensor g_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&g_data, 1); + Output g_const = Const(g_root.WithOpName("a").WithDevice("some_device"), + Input::Initializer(g_data)); + GraphDef g_graph_def; + TF_ASSERT_OK(g_root.ToGraphDef(&g_graph_def)); + const NodeDef& g_node_def = g_graph_def.node(0); + + NodeDef relu1_node_def; + relu1_node_def.set_op("Relu"); + relu1_node_def.set_name("a"); + relu1_node_def.add_input("foo"); + + NodeDef relu2_node_def; + relu2_node_def.set_op("Relu"); + relu2_node_def.set_name("a"); + relu2_node_def.add_input("bar"); + + EXPECT_EQ(HashNodeDef(a_node_def), HashNodeDef(b_node_def)); + EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(c_node_def)); + EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(d_node_def)); + EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(e_node_def)); + EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(f_node_def)); + EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(g_node_def)); + EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(relu1_node_def)); + EXPECT_NE(HashNodeDef(relu1_node_def), HashNodeDef(relu2_node_def)); + } + + void TestCountParameters() { + TransformFuncContext context; + context.params.insert({"foo", {"a", "b"}}); + context.params.insert({"bar", {"c"}}); + EXPECT_EQ(2, CountParameters(context, "foo")); + EXPECT_EQ(1, CountParameters(context, "bar")); + EXPECT_EQ(0, CountParameters(context, "not_present")); + } + + void TestGetExactlyOneParameter() { + TransformFuncContext context; + context.params.insert({"foo", {"a", "b"}}); + context.params.insert({"bar", {"c"}}); + string value; + TF_EXPECT_OK(GetExactlyOneParameter(context, "bar", "d", &value)); + EXPECT_EQ("c", value); + EXPECT_FALSE(GetExactlyOneParameter(context, "foo", "d", &value).ok()); + TF_EXPECT_OK(GetExactlyOneParameter(context, "not_present", "d", &value)); + EXPECT_EQ("d", value); + } }; TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } +TEST_F(TransformUtilsTest, TestMapNodesToOutputs) { TestMapNodesToOutputs(); } + TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) { TestNodeNamePartsFromInput(); } +TEST_F(TransformUtilsTest, TestCanonicalInputName) { TestCanonicalInputName(); } + +TEST_F(TransformUtilsTest, TestAddNodeInput) { TestAddNodeInput(); } + +TEST_F(TransformUtilsTest, TestCopyNodeAttr) { TestCopyNodeAttr(); } + +TEST_F(TransformUtilsTest, TestSetNodeAttr) { TestSetNodeAttr(); } + +TEST_F(TransformUtilsTest, TestSetNodeTensorAttr) { TestSetNodeTensorAttr(); } + +TEST_F(TransformUtilsTest, TestSetNodeTensorAttrWithTensor) { + TestSetNodeTensorAttrWithTensor(); +} + +TEST_F(TransformUtilsTest, TestGetNodeTensorAttr) { TestGetNodeTensorAttr(); } + TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); } TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); } TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); } +TEST_F(TransformUtilsTest, TestGetOpTypeMatches) { TestGetOpTypeMatches(); } + +TEST_F(TransformUtilsTest, TestGetOpTypeMatchesDAG) { + TestGetOpTypeMatchesDAG(); +} + +TEST_F(TransformUtilsTest, TestReplaceMatchingOpTypes) { + TestReplaceMatchingOpTypes(); +} + +TEST_F(TransformUtilsTest, TestMatchedNodesAsArray) { + TestMatchedNodesAsArray(); +} + +TEST_F(TransformUtilsTest, TestRenameNodeInputs) { TestRenameNodeInputs(); } + +TEST_F(TransformUtilsTest, TestRenameNodeInputsWithRedirects) { + TestRenameNodeInputsWithRedirects(); +} + +TEST_F(TransformUtilsTest, TestRenameNodeInputsWithCycle) { + TestRenameNodeInputsWithCycle(); +} + +TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) { + TestRenameNodeInputsWithWildcard(); +} + +TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); } + +TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); } + +TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); } + +TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); } + +TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); } + +TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) { + TestGetExactlyOneParameter(); +} + } // namespace graph_transforms } // namespace tensorflow