Create Graph Transform Tool for rewriting model files.

Change: 142729497
This commit is contained in:
Pete Warden 2016-12-21 20:50:31 -08:00 committed by TensorFlower Gardener
parent be60473c88
commit 0f0e29e7ba
47 changed files with 9065 additions and 170 deletions

View File

@ -252,6 +252,16 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
bool DoConstantFolding(const ConstantFoldingOptions& opts, bool DoConstantFolding(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env, FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph) { Device* partition_device, Graph* graph) {
bool was_mutated;
Status unused_status = DoConstantFoldingWithStatus(
opts, function_library, env, partition_device, graph, &was_mutated);
return was_mutated;
}
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library,
Env* env, Device* partition_device,
Graph* graph, bool* was_mutated) {
DumpGraph("Before", graph); DumpGraph("Before", graph);
const FunctionLibraryDefinition* flib_def = nullptr; const FunctionLibraryDefinition* flib_def = nullptr;
@ -263,7 +273,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes); FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes);
if (constant_foldable_nodes.empty()) { if (constant_foldable_nodes.empty()) {
VLOG(1) << "No constant foldable nodes found"; VLOG(1) << "No constant foldable nodes found";
return false; *was_mutated = false;
// This is not an error, so return the status as OK.
return Status::OK();
} }
std::map<NodeAndOutput, Node*> tensors_to_fetch; std::map<NodeAndOutput, Node*> tensors_to_fetch;
@ -273,7 +285,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
if (tensors_to_fetch.empty()) { if (tensors_to_fetch.empty()) {
VLOG(1) << "No constant nodes found that feed into the original graph."; VLOG(1) << "No constant nodes found that feed into the original graph.";
return false; *was_mutated = false;
// This is not an error, so return the status as OK.
return Status::OK();
} }
VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : " VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
<< graph->num_node_ids(); << graph->num_node_ids();
@ -292,7 +306,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
{} /* inputs*/, tensors_to_fetch_names, &outputs); {} /* inputs*/, tensors_to_fetch_names, &outputs);
if (!s.ok()) { if (!s.ok()) {
VLOG(1) << "Could not fetch constants: " << s; VLOG(1) << "Could not fetch constants: " << s;
return false; *was_mutated = false;
// This is not an error, so return the status as OK.
return s;
} }
// Fetch the constant tensors and replace the corresponding tensors in the // Fetch the constant tensors and replace the corresponding tensors in the
@ -307,7 +323,8 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
DumpGraph("After", graph); DumpGraph("After", graph);
return num_nodes_replaced > 0; *was_mutated = (num_nodes_replaced > 0);
return Status::OK();
} }
} // namespace tensorflow } // namespace tensorflow

View File

@ -29,7 +29,16 @@ namespace tensorflow {
// and replaces those nodes with the result of the evaluation. // and replaces those nodes with the result of the evaluation.
// "partition_device", if non-null, is the device where all the graph nodes are // "partition_device", if non-null, is the device where all the graph nodes are
// assumed to execute. // assumed to execute.
// Returns true if and only if "graph" has been mutated. // Sets `was_mutated` to true if and only if "graph" has been mutated.
// The status is only set to a non-OK state if an unexpected error is hit
// running the graph.
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library,
Env* env, Device* partition_device,
Graph* graph, bool* was_mutated);
// Version of the function that doesn't return a Status, for backwards
// compatibility.
bool DoConstantFolding(const ConstantFoldingOptions& opts, bool DoConstantFolding(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env, FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph); Device* partition_device, Graph* graph);

View File

@ -228,8 +228,12 @@ TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) {
g->AddControlEdge(concat_send, g->sink_node()); g->AddControlEdge(concat_send, g->sink_node());
// The above concat should not have been constant folded. // The above concat should not have been constant folded.
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, bool was_mutated;
Env::Default(), nullptr, g)); Status status =
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, g, &was_mutated);
EXPECT_FALSE(was_mutated);
TF_EXPECT_OK(status);
} }
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) { TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
@ -257,8 +261,12 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
g->AddControlEdge(times_two_send, g->sink_node()); g->AddControlEdge(times_two_send, g->sink_node());
// The above function call should not have been constant folded. // The above function call should not have been constant folded.
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, bool was_mutated;
Env::Default(), nullptr, g)); status =
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
Env::Default(), nullptr, g, &was_mutated);
EXPECT_FALSE(was_mutated);
EXPECT_FALSE(status.ok());
g_ = nullptr; g_ = nullptr;
} }
@ -337,10 +345,16 @@ TEST_F(ConstantFoldingTest, TestImmutableConst) {
auto result2 = ops::MatMul(root, result1, c); auto result2 = ops::MatMul(root, result1, c);
TF_ASSERT_OK(root.ToGraph(g)); TF_ASSERT_OK(root.ToGraph(g));
TestTFEnvironment test_env; TestTFEnvironment test_env;
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, bool was_mutated;
Env::Default(), nullptr, g)); Status status =
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, &test_env, DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
nullptr, g)); Env::Default(), nullptr, g, &was_mutated);
EXPECT_FALSE(was_mutated);
EXPECT_FALSE(status.ok());
status = DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
&test_env, nullptr, g, &was_mutated);
EXPECT_TRUE(was_mutated);
TF_EXPECT_OK(status);
} }
} // namespace } // namespace

View File

@ -41,11 +41,12 @@ template <typename Device, typename T>
class QuantizeV2Op : public OpKernel { class QuantizeV2Op : public OpKernel {
public: public:
explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
half_range_ = !std::is_signed<T>::value half_range_ =
? 0.0f !std::is_signed<T>::value
: (std::numeric_limits<T>::max() - ? 0.0f
std::numeric_limits<T>::min() + 1) / : (static_cast<double>(std::numeric_limits<T>::max()) -
2.0f; static_cast<double>(std::numeric_limits<T>::min()) + 1) /
2.0f;
string mode_string; string mode_string;
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
OP_REQUIRES(ctx, OP_REQUIRES(ctx,
@ -90,7 +91,8 @@ class QuantizeV2Op : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
if (mode_ == QUANTIZE_MODE_MIN_COMBINED) { if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
const float scale_factor = const float scale_factor =
(std::numeric_limits<T>::max() - std::numeric_limits<T>::min()) / (static_cast<double>(std::numeric_limits<T>::max()) -
static_cast<double>(std::numeric_limits<T>::min())) /
(max_range - min_range); (max_range - min_range);
// Quantize: // Quantize:
@ -162,5 +164,8 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"), Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
QuantizeV2Op<CPUDevice, qint16>); QuantizeV2Op<CPUDevice, qint16>);
REGISTER_KERNEL_BUILDER(
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
QuantizeV2Op<CPUDevice, qint32>);
} // namespace tensorflow } // namespace tensorflow

View File

@ -47,6 +47,46 @@ TEST_F(QuantizedOpTest, QuantizeV2) {
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0)); test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
} }
TEST_F(QuantizedOpTest, QuantizeV2_32Bit) {
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<qint32>::v())
.Attr("mode", "MIN_FIRST")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
const int element_count = 8;
AddInputFromArray<float>(
TensorShape({element_count}),
{-500.0f, 0.0f, 1.0f, 1.25f, 1.75f, 127.0f, 255.0f, 500.0f});
AddInputFromArray<float>(TensorShape({1}), {-256.0f});
AddInputFromArray<float>(TensorShape({1}), {256.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT32, TensorShape({element_count}));
test::FillValues<qint32>(&expected,
{
std::numeric_limits<int32>::min(), 0,
static_cast<int32>(1.0f * (1 << 23)),
static_cast<int32>(1.25f * (1 << 23)),
static_cast<int32>(1.75f * (1 << 23)),
static_cast<int32>(127.0f * (1 << 23)),
static_cast<int32>(255.0f * (1 << 23)),
std::numeric_limits<int32>::max(),
});
// We expect there will be some fuzziness in the lower bits, since this is
// converting from float.
const int64 epsilon = 1 << 8;
const qint32* output_data = GetOutput(0)->flat<qint32>().data();
const qint32* expected_data = expected.flat<qint32>().data();
for (int i = 0; i < element_count; ++i) {
const int64 delta = output_data[i] - expected_data[i];
EXPECT_GT(epsilon, std::abs(delta))
<< "output_data[" << i << "]=" << output_data[i] << ", expected_data["
<< i << "]=" << expected_data[i] << ", delta=" << delta;
}
}
TEST_F(QuantizedOpTest, QuantizeV2Ports) { TEST_F(QuantizedOpTest, QuantizeV2Ports) {
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2") TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
.Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT))

View File

@ -29,6 +29,7 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",
], ],
@ -52,9 +53,23 @@ tf_cc_test(
) )
cc_library( cc_library(
name = "fold_constants_lib", name = "transforms_lib",
srcs = [ srcs = [
"fold_batch_norms.cc",
"fold_constants_lib.cc", "fold_constants_lib.cc",
"fold_old_batch_norms.cc",
"fuse_convolutions.cc",
"obsfucate_names.cc",
"quantize_nodes.cc",
"quantize_weights.cc",
"remove_attribute.cc",
"remove_device.cc",
"remove_nodes.cc",
"rename_attribute.cc",
"rename_op.cc",
"round_weights.cc",
"sort_by_execution_order.cc",
"strip_unused_nodes.cc",
], ],
hdrs = [ hdrs = [
"fold_constants_lib.h", "fold_constants_lib.h",
@ -65,20 +80,98 @@ cc_library(
":transform_utils", ":transform_utils",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",
"//tensorflow/core/kernels:quantized_ops",
],
alwayslink = 1,
)
tf_cc_test(
name = "transforms_test",
size = "small",
srcs = [
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
"fold_old_batch_norms_test.cc",
"fuse_convolutions_test.cc",
"obsfucate_names_test.cc",
"quantize_nodes_test.cc",
"quantize_weights_test.cc",
"remove_attribute_test.cc",
"remove_device_test.cc",
"remove_nodes_test.cc",
"rename_attribute_test.cc",
"rename_op_test.cc",
"round_weights_test.cc",
"sort_by_execution_order_test.cc",
"strip_unused_nodes_test.cc",
],
deps = [
":transform_utils",
":transforms_lib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:quantized_ops",
],
)
cc_library(
name = "transform_graph_lib",
srcs = ["transform_graph.cc"],
hdrs = ["transform_graph.h"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":transform_utils",
":transforms_lib",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
)
# This library includes a main function, to make it easy to create other
# versions of the tool linked against different operator libs.
cc_library(
name = "transform_graph_main_lib",
srcs = ["transform_graph_main.cc"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":transform_graph_lib",
":transforms_lib",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_binary(
name = "transform_graph",
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":transform_graph_main_lib",
], ],
) )
tf_cc_test( tf_cc_test(
name = "fold_constants_test", name = "transform_graph_test",
size = "small", size = "medium",
srcs = ["fold_constants_test.cc"], srcs = ["transform_graph_test.cc"],
deps = [ deps = [
":fold_constants_lib", ":transform_graph_lib",
":transform_utils", ":transform_utils",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/cc:sendrecv_ops", "//tensorflow/cc:sendrecv_ops",
@ -95,23 +188,24 @@ tf_cc_test(
# This library includes a main function, to make it easy to create other # This library includes a main function, to make it easy to create other
# versions of the tool linked against different operator libs. # versions of the tool linked against different operator libs.
cc_library( cc_library(
name = "fold_constants_main_lib", name = "summarize_graph_main_lib",
srcs = ["fold_constants_tool.cc"], srcs = ["summarize_graph_main.cc"],
copts = tf_copts(), copts = tf_copts(),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":fold_constants_lib", ":transform_utils",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
], ],
) )
cc_binary( cc_binary(
name = "fold_constants_tool", name = "summarize_graph",
copts = tf_copts(), copts = tf_copts(),
linkstatic = 1, linkstatic = 1,
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":fold_constants_main_lib", ":summarize_graph_main_lib",
], ],
) )

View File

@ -0,0 +1,858 @@
# Graph Transform Tool
## Table of Contents
* [Introduction](#introduction)
* [Using the Graph Transform Tool](#using-the-graph-transform-tool)
* [Inspecting Graphs](#inspecting-graphs)
* [Common Use Cases](#common-use-cases)
* [Optimizing for Deployment](#optimizing-for-deployment)
* [Fixing Missing Kernel Errors on
Mobile](#fixing-missing-kernel-errors-on-mobile)
* [Shrinking File Size](#shrinking-file-size)
* [Eight-bit Calculations](#eight-bit-calculations)
* [Transform Reference](#transform-reference)
* [fold_batch_norms](#fold_batch_norms)
* [fold_constants](#fold_constants)
* [fold_old_batch_norms](#fold_old_batch_norms)
* [fuse_convolutions](#fuse_convolutions)
* [merge_duplicate_nodes](#merge_duplicate_nodes)
* [obsfucate_names](#obsfucate_names)
* [quantize_nodes](#quantize_nodes)
* [quantize_weights](#quantize_weights)
* [remove_attribute](#remove_attribute)
* [remove_device](#remove_device)
* [remove_nodes](#remove_nodes)
* [rename_attribute](#rename_attribute)
* [rename_op](#rename_op)
* [round_weights](#round_weights)
* [sort_by_execution_order](#sort_by_execution_order)
* [strip_unused_nodes](#strip_unused_nodes)
* [Writing Your Own Transforms](#writing-your-own-transforms)
* [Transform Functions](#transform-functions)
* [Pattern Syntax](#pattern-syntax)
* [ReplaceMatchingOpTypes](#replacematchingoptypes)
* [Parameters](#parameters)
* [Function Libraries](#function-libraries)
* [Registering](#registering)
## Introduction
When you have finished training a model and want to deploy it in production,
you'll often want to modify it to better run in its final environment. For
example if you're targeting a phone you might want to shrink the file size by
quantizing the weights, or optimize away batch normalization or other
training-only features. The Graph Transform framework offers a suite of tools
for modifying computational graphs, and a framework to make it easy to write
your own modifications.
This guide is structured into three main parts, first giving some tutorials on
how to perform common tasks, second a reference covering all of the different
transformations that are included, together with the options that apply to them,
and third a guide to creating your own transforms.
## Using the Graph Transform Tool
The Graph Transform tool is designed to work on models that are saved as
GraphDef files, usually in a binary protobuf format. This is the low-level
definition of a TensorFlow computational graph, including a list of nodes and
the input and output connections between them. If you're using a Python API to
train your model, this will usually be saved out in the same directory as your
checkpoints, and usually has a '.pb' suffix.
If you want to work with the values of your trained parameters, for example to
quantize weights, you'll need to run
[tensorflow/python/tools/freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)
to convert the checkpoint values into embedded constants within the graph file
itself.
You call the Graph Transform tool itself like this:
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
strip_unused_nodes(type=float, shape="1,299,299,3") \
remove_nodes(op=Identity, op=CheckNumerics) \
fold_old_batch_norms \
'
```
The arguments here are specifying where to read the graph from, where to write
the transformed version to, what the input and output layers are, and what
transforms to modify the graph with. The transforms are given as a list of
names, and can each have arguments themselves. These transforms define the
pipeline of modifications that are applied in order to produce the output.
Sometimes you need some transforms to happen before others, and the ordering
within the list lets you specify which happen first.
## Inspecting Graphs
Many of the transforms that the tool supports need to know what the input and
output layers of the model are. The best source for these is the model training
process, where for a classifier the inputs will be the nodes that receive the
data from the training set, and the output will be the predictions. If you're
unsure, the
[summarize_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph.cc)
can inspect the model and provide guesses about likely input and output nodes,
as well as other information that's useful for debugging. Here's an example of
how to use it on the [Inception V3
graph](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz):
```bash
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=tensorflow_inception_graph.pb
```
## Common Use Cases
This section has small guides for some of the most frequently-used
transformation pipelines, aimed at users who want to quickly accomplish one of
these tasks. A lot of them will use the Inception V3 model for their examples,
which can be downloaded from
[http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
### Optimizing for Deployment
If you've finished training your model and want to deploy it on a server or a
mobile device, you'll want it to run as fast as possible, and with as few
non-essential dependencies as you can. This recipe removes all of the nodes that
aren't called during inference, shrinks expressions that are always constant
into single nodes, and optimizes away some multiply operations used during batch
normalization by pre-multiplying the weights for convolutions.
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
strip_unused_nodes(type=float, shape="1,299,299,3") \
remove_nodes(op=Identity, op=CheckNumerics) \
fold_constants(ignore_errors=true) \
fold_batch_norms \
fold_old_batch_norms\
'
```
The batch norm folding is included twice because there are two different flavors
of batch normalization used in TensorFlow. The older version was implemented
with a single BatchNormWithGlobalNormalization op, but it was deprecated in
favor of a more recent approach using individual ops to implement the same
computation. The two transforms are in there so that both styles are recognized
and optimized.
### Fixing Missing Kernel Errors on Mobile
The mobile version of TensorFlow is focused on inference, and so by default the
list of supported ops (defined in
[tensorflow/core/kernels/BUILD:android_extended_ops](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/BUILD#L2452)
for Bazel and
[tensorflow/contrib/makefile/tf_op_files.txt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/tf_op_files.txt)
for make builds) doesn't include a lot that are training related. This can cause
`No OpKernel was registered to support Op` errors when a GraphDef is loaded,
even if the op isn't going to be executed.
If you see this error and it's an op that you do actually want to run on mobile,
then you'll need to make local modifications to the build files to include the
right .cc file that defines it. In a lot of cases the op is just a vestigial
remnant from the training process though, and if that's true then you can run
the [strip_unused_nodes](#strip_unused_nodes), specifying the inputs and outputs
of your inference usage, to remove those unneccessary nodes:
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
strip_unused_nodes(type=float, shape="1,299,299,3") \
fold_constants \
fold_batch_norms \
fold_old_batch_norms\
'
```
### Shrinking File Size
If you're looking to deploy your model as part of a mobile app, then keeping the
download size as small as possible is important. For most TensorFlow models, the
largest contributors to the file size are the weights passed in to convolutional
and fully-connected layers, so anything that can reduce the storage size for
those is very useful. Luckily most neural networks are resistant to noise, so
it's possible to change the representation of those weights in a lossy way
without losing very much accuracy overall.
On both iOS and Android app packages are compressed before download, so the
simplest way to reduce the bandwidth your users need to receive your app is to
provide raw data that compresses more easily. By default the weights are stored
as floating-point values, and even tiny differences between numbers result in
very different bit patterns, and so these don't compress very well. If you round
the weights so that nearby numbers are stored as exactly the same values, the
resulting bit stream has a lot more repetition and so compresses down a lot more
effectively. To try this technique on your model, run the
[round_weights](#round_weights] transform.
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
round_weights(num_steps=256) \
'
```
You should see that the `optimized_inception_graph.pb` output file is the same
size as the input, but if you run zip on it to compress it, it's almost 70%
smaller than if you zip the original! The nice thing about this transform is
that it doesn't change the structure of the graph at all, so it's running
exactly the same operations and should have the same latency and memory usage as
before. You can adjust the `num_steps` parameter to control how many values each
weight buffer is rounded to, so lower numbers will increase the compression at
the cost of accuracy.
As a further step, you can store the weights into eight-bit values directly.
Here's the recipe for that:
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
quantize_weights \
'
```
You should see that the size of the output graph is about a quarter of the
original. The downside to this approach compared to round_weights is that extra
decompression ops are inserted to convert the eight-bit values back into
floating point, but optimizations in TensorFlow's runtime should ensure these
results are cached and so you shouldn't see the graph run any more slowly.
So far we've been concentrating on weights because those generally take up the
most space. If you have a graph with a lot of small nodes in it, the names of
those nodes can start to take up a noticeable amount of space too. To shrink
those down, you can run the [obsfucate_names](#obsfucate_names) transform, which
replaces all the names (except for inputs and outputs) with short, cryptic but
unique ids:
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
obsfucate_names \
'
```
### Eight-bit Calculations
For some platforms it's very helpful to be able to do as many calculations as
possible in eight-bit, rather than floating-point. The support for this in
TensorFlow is still experimental and evolving, but you can convert models into
quantized form using the graph transform tool:
```bash
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=tensorflow_inception_graph.pb \
--out_graph=optimized_inception_graph.pb \
--inputs='Mul:0' \
--outputs='softmax:0' \
--transforms='\
strip_unused_nodes(type=float, shape="1,299,299,3") \
remove_nodes(op=Identity, op=CheckNumerics) \
fold_old_batch_norms \
quantize_weights \
quantize_nodes \
strip_unused_nodes \
'
```
This process converts all the operations in the graph that have quantized
equivalents, and leaves the rest in floating point. Only a subset of ops are
supported, and on many platforms the quantized code may actually be slower than
the float equivalents, but this is a way of increasing performance substantially
when all the circumstances are right.
A full guide to optimizing for quantization is beyond the scope of this guide,
but one thing that can help is using the FakeQuantWithMinMaxVars op after Conv2D
or similar operations during training. This trains the min/max variables that
control the range used for quantization, so that the range doesn't have to be
calculated dynamically by RequantizationRange during inference.
## Transform Reference
The transforms string is parsed as a series of transform names, each of which
can have multiple named arguments inside parentheses. Arguments are separated by
commas, and double-quotes (") can be used to hold argument values if they
themselves contain commas (for example shape definitions).
The --inputs and --outputs are shared across all transforms, since it's common
to need to know what the ingoing and outgoing nodes in the graph are. You should
make sure you set these correctly before calling the graph transform tool, and
if you're in doubt check with the model's author, or use the `check_graph` tool
to examine likely inputs and outputs.
All transforms can be passed the `ignore_errors` flag, with the value set to
either true or false. By default any errors that happen within a transform will
abort the whole process, but if you enable this then an error will just be
logged and the transform skipped. This is especially useful for optional
transforms where version errors or other unimportant problems may trigger an
error.
### fold_batch_norms
Args: None
This transform tries to optimize away the Mul that's introduced after a Conv2D
when batch normalization has been used during training. It scans the graph for
any channel-wise multiplies immediately after convolutions, and multiplies the
convolution's weights with the mul instead so it can be omitted. You'll need to
make sure you run [fold_constants](#fold_constants) first, since the pattern can
only be spotted if the normal complex expression that's produced by training for
the Mul input is collapsed down into a simple constant.
### fold_constants
Args: None
Looks for any sub-graphs within the model that always evaluate to constant
expressions, and replaces them with those constants. This optimization is always
executed at run-time after the graph is loaded, so running it offline first
won't help latency, but it can simplify the graph and so make further processing
easier. It's often useful to call this with `fold_constants(ignore_errors=true)`
to continue on past transient errors, since this is just an optimization phase.
### fold_old_batch_norms
Args: None
In the early days of TensorFlow, batch normalization was implemented using a
single monolithic `BatchNormWithGlobalNormalization` op. In modern versions,
adding batch normalization from Python will give you a series of smaller math
ops instead, to achieve the same effect without special-purpose code. If you
have a graph that uses the older-style, this transform will recognize and
optimize those ops for inference, in the same way that the
[fold_batch_norms](#fold_batch_norms) transform does for the new approach.
### fuse_convolutions
Args: None
For graphs that use ResizeBilinear or MirrorPad ops before convolutions,
typically to scale up in the later stages of an image style transfer model for
example, it can save on memory usage and latency to combine the spatial
transformations with the convolution's im2col patch generation. This transform
looks out for that particular pattern of ops and replaces them with a fused
version that combines the resizing and padding with the convolution.
### merge_duplicate_nodes
Args: None
If there are Const nodes with the same types and contents, or nodes with the
same inputs and attributes, this transform will merge them together. It can be
useful when you want to cut down the number of nodes in a graph that has a lot
of redundancy, and is always run as part of [quantize_nodes](#quantize_nodes)
since the processing there can introduce duplicates of constants that are used
in the quantize/dequantize process.
### obsfucate_names
Args: None
Replaces all node's names with short generated ids, other than the inputs and
outputs. This also updates all references within the graph so that the structure
is preserved. This can be useful if you want to shrink the file size, or if you
want to make it harder to understand the architecture of your model before
releasing it.
### quantize_nodes
Args:
* input_min: The lowest float value for any quantized placeholder inputs.
* input_max: The highest float value for any quantized placeholder inputs. If
both input_min and input_max are set, then any float placeholders in the
graph will be replaced with quantized versions, and consts will be created
to pass the range to subsequent operations.
* fallback_min: The lowest float value to use for requantizing activation
layers.
* fallback_max: The highest float value to use for requantizing activation
layers. If both fallback_min and fallback_max are set, then instead of using
RequantizationRange ops ro figure out the useful range dynamically when
converting the 32-bit output of ops like QuantizedConv2D and
QuantizedBiasAdd, hardwired consts with these values will be used instead.
This can help performance, if you know the range of your activation layers
ahead of time.
Replaces any calculation nodes with their eight-bit equivalents, and adds in
conversion layers to allow remaining float operations to interoperate. This is
one of the most complex transforms, and involves multiple passes and a lot of
rewriting. It's also still an active area of research, so results may vary
depending on the platform and operations you're using in your model. You should
run [quantize_weights](#quantize_weights) first to ensure your Const ops are in
eight-bit form.
### quantize_weights
Args: None
Converts any large (more than 15 element) float Const op into an eight-bit
equivalent, followed by a float conversion op so that the result is usable by
subsequent nodes. This is mostly useful for [shrinking file
sizes](#shrinking-file-size), but also helps with the more advanced
[quantize_nodes](#quantize_nodes) transform.
### remove_attribute
Args:
* attribute_name: Name of the attribute you want to remove.
* op_name: Optional name of a single op to restrict the removal to.
Deletes the given attribute from either all nodes, or just the one specified in
`op_name`. This can be a dangerous transform since it's easy to leave your graph
in an invalid state if you remove a required attribute. It can be useful in
special circumstances though.
### remove_device
Args: None
All ops can have a hardware device specified. This can be a problem when you're
loading a graph on a different system than the model was trained on, since some
specified devices may not be available. In order to work with graphs like these,
you can run this transform to wipe the slate clean and delete the device
specifier from all ops.
### remove_nodes
Args:
* op: The name of the op you want to remove. Can be repeated to remove
multiple ops.
This is a potentially dangerous transform that looks for single-input,
single-output ops with the given names, removes them from the graph, and rewires
all inputs that use to pull from them to pull from the preceding node instead.
This is most useful for getting rid of ops like `CheckNumerics` that are useful
during training but just complicate the graph and increase latency during
inference. It's dangerous because it's possible that removing some ops may
change the output of your graph, so make sure you check the overall accuracy
after using this.
### rename_attribute
Args:
* old_attribute_name: Current name of the attribute you want to rename.
* new_attribute_name: Name that you want the attribute to have now.
* op_name: If this is set, only change attributes for a given op type,
otherwise apply to all nodes with attribute names that match.
Changes the name of the given attribute. This is often useful for upgrading
graph files as op definitions change over versions, since the renaming is often
enough to deal with minor changes.
### rename_op
Args:
* old_op_name: Current name of the operation.
* new_op_name: Name to change to.
Finds all ops with the given name, and changes them to the new one. This can be
useful for version upgrading if the changes between ops are minor apart from the
name.
### round_weights
Args:
* num_steps: How many unique values to use in each buffer.
Rounds all float values in large Const ops (more than 15 elements) to the given
number of steps. The unique values are chosen per buffer by linearly allocating
between the largest and smallest values present. This is useful when you'll be
deploying on mobile, and you want a model that will compress effectively. See
[shrinking file size](#shrinking-file-size) for more details.
### sort_by_execution_order
Args: None
Arranges the nodes in the GraphDef in topological order, so that the inputs of
any given node are always earlier than the node itself. This is especially
useful when you're targeting a minimal inference engine, since you can just
execute the nodes in the given order knowing that the inputs will be computed
before they're needed.
### strip_unused_nodes
Args:
* type: Default type for any new Placeholder nodes generated, for example
int32, float, quint8.
* shape: Default shape for any new Placeholder nodes generated, as
comma-separated dimensions. For example shape="1,299,299,3". The double
quotes are important, since otherwise the commas will be taken as argument
separators.
* name: Identifier for the placeholder arguments.
* type_for_name: What type to use for the previously-given name.
* shape_for_name: What shape to use for the previously-given name.
Removes all nodes not used in calculated the layers given in `--outputs`, fed by
`--inputs`. This is often useful for removing training-only nodes like
save-and-restore or summary ops. It's also handy for solving the [missing kernel
errors problem](#fixing-missing-kernel-errors-on-mobile) when there are decode
or other ops you don't need in the inference path.
The biggest complication is that it sometimes has to create new Placeholder ops,
so there are options to control their characteristics. This will happen if you
bypass a DecodeJpeg op by specifying an input layer deeper in the network, for
example, so you can pass in a raw image array instead of an encoded string as an
input. The decode op will be removed, together with the Placeholder that fed it,
but a new Placeholder is needed for the input layer you specify. The type and
shape arguments let you control the attributes of any new Placeholders that are
created. Plain `type` and `shape` set global defaults, but if you have different
inputs with varying characteristics, you'll need to pass in a list of arguments
where the preceding name specifies what layer each applies to. For example, if
you had two inputs in1 and in2, you could call `strip_unused_node(name=in1,
type_for_name=int32, shape_for_name="2,3", name=in2, type_for_name=float,
shape_for_name="1,10,10,3")`.
## Writing Your Own Transforms
The Graph Transform Tool is designed to make it as easy as possible to create
your own optimization, modification, and pre-processing transforms. At their
heart, all of the transforms take in a valid GraphDef, make some changes, and
output a new GraphDef. Each GraphDef is just a list of NodeDefs, each defining
one node in the graph and its connections. You can find more information on the
format at [this guide to TensorFlow model
files](https://www.tensorflow.org/versions/master/how_tos/tool_developers/index.html),
but for a simple example take a look at
[tensorflow/tools/graph_transforms/rename_op.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/rename_op.cc),
which implements the [rename_op](#rename_op) transform:
```C++
Status RenameOp(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
if (!context.params.count("old_op_name") ||
(context.params.at("old_op_name").size() != 1) ||
!context.params.count("new_op_name") ||
(context.params.at("new_op_name").size() != 1)) {
return errors::InvalidArgument(
"remove_nodes expects exactly one 'old_op_name' and 'new_op_name' "
"argument, e.g. rename_op(old_op_name=Mul, new_op_name=Multiply)");
}
const string old_op_name = context.params.at("old_op_name")[0];
const string new_op_name = context.params.at("new_op_name")[0];
output_graph_def->Clear();
for (const NodeDef& node : input_graph_def.node()) {
NodeDef* new_node = output_graph_def->mutable_node()->Add();
new_node->CopyFrom(node);
if (node.op() == old_op_name) {
new_node->set_op(new_op_name);
}
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("rename_op", RenameOp);
```
The heart of this transform is the loop through the input_graph_def's nodes. We
go through each op, add a new one to the output, copy the original's contents,
and then change the op over if it matches the parameters. There's a standard set
of parameters for every transform, so they all take in a GraphDef and context,
and write out into a new GraphDef. The registration macro at the bottom lets the
tool know what function to call when it finds the `rename_op` string in a
transforms list.
### Transform Functions
The standard signature that all transform functions have is defined as
`TransformFunc`, which takes in an input GraphDef, a `TransformFuncContext`
containing environment information, writes to an output GraphDef, and returns a
Status indicating whether the transform succeeded.
The `TransformFuncContext` has a list of the inputs and outputs for the graph,
and the [parameter arguments](#parameters) that were passed into the transform
by the user.
If you write a function that matches this signature, and [register
it](#registration), the graph transform tool will take care of calling it.
### Pattern Syntax
The `rename_op` example only needs to look at a single node at a time, but one
of the most common needs is to modify small sub-graphs within a model. To make
this easy, the Graph Transform Tool provides the `OpTypePattern` syntax. This is
a simple and compact way to specify patterns of nodes that you want to look for.
For example, if you want all Conv2D nodes that have a constant as their second
input, you would set up a pattern like this, using C++ initializer lists to
populate the structure:
```C++
OpTypePattern conv_pattern({"Conv2D", {{"*"}, {"Const"}}});
```
It can be easier to visualize these initializers using indentation to show the
tree structure more clearly:
```C++
OpTypePattern conv_pattern({
"Conv2D",
{
{"*"},
{"Const"}
}
});
```
In plain English this is saying, a Conv2D op with two inputs, the first of which
is any op type, and the second is a Const op.
The op field can either contain a single "*", which means match any op type, one
op type (for example "Const"), or a set of op types separated by `|` symbols
(for example "Conv2D|MatMul|BiasAdd"). General regex patterns are not supported,
just these special cases.
You can think of these patterns as very limited regular expressions designed to
pick out sub-trees in graphs. They are deliberately very constrained to the kind
of things we commonly find ourselves needing to do, to make creating and
debugging as straightforward as possible.
Here's a much more complex example, from the [quantize_nodes](#quantize_nodes)
transform:
```C++
{"QuantizeV2",
{
{"Dequantize"},
{"Min",
{
{"Reshape",
{
{"Dequantize"},
{"Const"},
}
},
{"Const"},
}
},
{"Max",
{
{"Reshape",
{
{"Dequantize"},
{"Const"},
}
},
{"Const"},
}
},
}
}
```
This is looking for QuantizeV2 nodes, with three inputs, the first of which is a
Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the
third is a Max which does the same. We know the end result of this sub-graph is
a no-op, since it's just turning an eight-bit buffer into float, and then
immediately converting it back to eight-bits, so if we look for this pattern and
remove it we can optimize the graph without changing the result.
### ReplaceMatchingOpTypes
It's very common to want to find all occurrences of a particular sub-graph in a
model, and replace them all with a different sub-graph that keeps the same local
input and output connections. For example with
[fuse_convolutions](#fuse_convolutions), we needed to find all Conv2D ops that
read their inputs from BilinearResizes, and replace those combinations with a
single FusedResizeAndPadConv2D op, but without affecting other ops.
To make that sort of transformation easy, we created the
`ReplaceMatchingOpTypes` helper. This takes in a graph, an `OpTypePattern`
defining the sub-graph to look for, and a callback function to run for every
occurrence it finds. The job of this callback function is to look at the
`NodeMatch` that contains information about the current sub-graph, and return a
new sub-graph in the new_nodes list that will be used to replace the old
sub-graph.
You can see how it's used in practice in the
[fuse_convolutions](#fuse_convolutions) code:
```C++
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"Conv2D",
{
{"ResizeBilinear"},
{"*"}
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
// Find all the nodes we expect in the subgraph.
const NodeDef& conv_node = match.node;
const NodeDef& resize_node = match.inputs[0].node;
const NodeDef& weights_node = match.inputs[1].node;
// We'll be reusing the old weights.
new_nodes->push_back(weights_node);
// Create a 'no-op' mirror padding node that has no effect.
NodeDef pad_dims_node;
pad_dims_node.set_op("Const");
pad_dims_node.set_name(conv_node.name() + "_dummy_paddings");
SetNodeAttr("dtype", DT_INT32, &pad_dims_node);
SetNodeTensorAttr<int32>("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0},
&pad_dims_node);
new_nodes->push_back(pad_dims_node);
// Set up the new fused version of the convolution op.
NodeDef fused_conv;
fused_conv.set_op("FusedResizeAndPadConv2D");
fused_conv.set_name(match.node.name());
AddNodeInput(resize_node.input(0), &fused_conv);
AddNodeInput(resize_node.input(1), &fused_conv);
AddNodeInput(pad_dims_node.name(), &fused_conv);
AddNodeInput(conv_node.input(1), &fused_conv);
CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
&fused_conv);
SetNodeAttr("mode", "REFLECT", &fused_conv);
CopyNodeAttr(conv_node, "T", "T", &fused_conv);
CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
new_nodes->push_back(fused_conv);
return Status::OK();
},
{}, &replaced_graph_def));
```
Here you can see we define the pattern to look for, and in the callback function
use information from each of the nodes in the old sub-graph to create a new
fused node. We also copy over the old weights input node so that isn't lost.
There are a few things to know about the `ReplaceMatchingOpTypes` function:
* All of the nodes in any matching sub-graphs are removed from the new graph
created by the function. If any of them are needed, it's the callback
function's responsibility to add them back in. There's a `CopyOriginalMatch`
convenience call that will copy over all of the original nodes if you decide
you don't actually want to modify a particular sub-graph.
* Nodes will never appear in more than one matched sub-graph. This is to
ensure that sub-trees are only replaced once, but it may mean that some
sub-graphs aren't spotted if they overlap with earlier matches.
* The calling framework tries to ensure that the graph remains sane, by
looking at the new_nodes that are returned and making sure that no nodes
which are needed as inputs by nodes outside the sub-graph are removed. These
important nodes are listed in the `output_nodes` argument that's passed into
each replacement function call. You can disable this checking by setting
`allow_inconsistencies` to true in the options, but otherwise any
replacements that break the graph constraints will be cancelled. If you do
allow inconsistencies, it's your transform's responsibility to fix them up
before you return your final result. Functions like `RenameNodeInputs` can
be useful if you are doing wholesale node renaming for example.
### Parameters
The arguments that are in parentheses after the transform name when the tool is
called are parsed and placed into the params member of the TransformFuncContext
that's given to each transform. For every named argument, there's a vector of
strings containing all the values that it was given, in the order they were
given. These are treated a bit like command-line parameters, and it's the
transform's responsibility to parse them into the data types it needs, and raise
errors by returning a bad Status if any of them are ill-formed.
As an example, here's a hypothetical transform call:
```
some_transform(foo=a, foo=b, bar=2, bob="1,2,3")
```
Here's what the std::map of strings looks like in the params member:
```
{{"foo", {"a", "b"}}, {"bar", {"2"}}, {"bob", {"1,2,3"}}}
```
The double quotes around the comma-separated argument to `bob` are important
because otherwise they'll be treated as separate arguments, and the parsing will
fail.
Here's an example of how [round_weights](#round_weights) reads its `num_steps`
parameter:
```C++
string num_steps_string;
TF_RETURN_IF_ERROR(
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
int32 num_steps;
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
return errors::InvalidArgument(
"Couldn't interpret the num_steps argument to round_weights as a "
"number:",
num_steps_string);
}
```
Things to notice here are that you have to convert the string to an integer, and
if the conversion fails you need to raise a meaningful error through the status
result of the transform. We're also using a helper function which raises an
error if the parameter is present multiple times, and uses a default if the user
hasn't specified it.
### Function Libraries
A newer feature of TensorFlow is the ability to create libraries of functions as
part of graphs. These are a bit like templates, which define macro operations in
terms of smaller components, which can then be instantiated with different input
and output connections inside the graph just like regular ops. Right now the
graph transform tool just copies these libraries between the input and output
graphs, but it's likely that more complex operations will be supported on them
in the future.
### Registering
The Graph Transform Tool associates names of transforms with the code to
implement them using the `REGISTER_GRAPH_TRANSFORM()` macro. This takes a string
and a function, and automagically registers the transform with the tool. You
will need to watch out for a few things though:
* Because it's using global C++ objects in each file under the hood, the
linker can sometimes strip them out and lose the registration. In Bazel you
need to make sure you're linking any new transforms in as libraries, and use
the `alwayslink` flag in your `cc_binary` call.
* You should be able to create your own copy of the transform_graph tool by
linking against the transform_graph_main_lib library in
tensorflow/tools/graph_transforms/BUILD. This contains all the `main()`
logic to parse command line arguments and call transforms.

View File

@ -0,0 +1,108 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the
// Mul baked into the convolution weights, to save computation during inference.
Status FoldBatchNorms(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"Mul", // mul_node
{
{"Conv2D", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
}
},
{"Const"}, // mul_values_node
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
// Find all the nodes we expect in the subgraph.
const NodeDef& mul_node = match.node;
const NodeDef& conv_node = match.inputs[0].node;
const NodeDef& input_node = match.inputs[0].inputs[0].node;
const NodeDef& weights_node = match.inputs[0].inputs[1].node;
const NodeDef& mul_values_node = match.inputs[1].node;
Tensor weights = GetNodeTensorAttr(weights_node, "value");
Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value");
// Make sure all the inputs really are vectors, with as many entries as
// there are columns in the weights.
const int64 weights_cols = weights.shape().dim_size(3);
if ((mul_values.shape().dims() != 1) ||
(mul_values.shape().dim_size(0) != weights_cols)) {
return errors::InvalidArgument(
"Mul constant input to batch norm has bad shape: ",
mul_values.shape().DebugString());
}
// Multiply the original weights by the scale vector.
auto weights_matrix = weights.flat_inner_dims<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
for (int64 col = 0; col < weights_cols; ++col) {
scaled_weights_matrix(row, col) =
weights_matrix(row, col) * mul_values.flat<float>()(col);
}
}
// Construct the new nodes.
NodeDef scaled_weights_node;
scaled_weights_node.set_op("Const");
scaled_weights_node.set_name(weights_node.name());
SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
new_nodes->push_back(scaled_weights_node);
new_nodes->push_back(input_node);
NodeDef new_conv_node;
new_conv_node.CopyFrom(conv_node);
new_conv_node.set_name(mul_node.name());
new_nodes->push_back(new_conv_node);
return Status::OK();
},
{}, &replaced_graph_def));
*output_graph_def = replaced_graph_def;
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("fold_batch_norms", FoldBatchNorms);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,93 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status FoldBatchNorms(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class FoldBatchNormsTest : public ::testing::Test {
protected:
void TestFoldBatchNorms() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
test::FillValues<float>(&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
{1, 1, 1, 1}, "VALID");
Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
Output mul_values_op = Const(root.WithOpName("mul_values"),
Input::Initializer(mul_values_data));
Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef fused_graph_def;
TF_ASSERT_OK(
FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("Mul", node.op());
}
}
};
TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -95,17 +95,16 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
} }
Status RemoveUnusedNodes(const GraphDef& input_graph_def, Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const std::vector<string>& inputs, const TransformFuncContext& context,
const std::vector<string>& outputs,
GraphDef* output_graph_def) { GraphDef* output_graph_def) {
std::map<string, const NodeDef*> node_map; std::map<string, const NodeDef*> node_map;
MapNamesToNodes(input_graph_def, &node_map); MapNamesToNodes(input_graph_def, &node_map);
std::map<string, bool> used_nodes; std::map<string, bool> used_nodes;
for (const string& input : inputs) { for (const string& input : context.input_names) {
used_nodes[input] = true; used_nodes[input] = true;
} }
std::vector<string> current_nodes = outputs; std::vector<string> current_nodes = context.output_names;
while (!current_nodes.empty()) { while (!current_nodes.empty()) {
std::vector<string> next_nodes; std::vector<string> next_nodes;
for (const string& node_name : current_nodes) { for (const string& node_name : current_nodes) {
@ -134,9 +133,10 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
return Status::OK(); return Status::OK();
} }
// Converts any sub-graphs that can be resolved into constant expressions into
// single Const ops.
Status FoldConstants(const GraphDef& input_graph_def, Status FoldConstants(const GraphDef& input_graph_def,
const std::vector<string>& inputs, const TransformFuncContext& context,
const std::vector<string>& outputs,
GraphDef* output_graph_def) { GraphDef* output_graph_def) {
// Some older GraphDefs have saved _output_shapes attributes which are out of // Some older GraphDefs have saved _output_shapes attributes which are out of
// date and cause import errors, so clean them up first. // date and cause import errors, so clean them up first.
@ -148,20 +148,24 @@ Status FoldConstants(const GraphDef& input_graph_def,
ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr)); ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
DeviceAttributes device_attributes; DeviceAttributes device_attributes;
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
&input_graph, inputs, outputs, {}, device_attributes)); &input_graph, context.input_names, context.output_names, {},
if (!DoConstantFolding(ConstantFoldingOptions(), nullptr, Env::Default(), device_attributes));
nullptr, &input_graph)) { bool was_mutated;
return errors::InvalidArgument("Constant folding failed"); TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
} ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,
&was_mutated));
GraphDef folded_graph_def; GraphDef folded_graph_def;
input_graph.ToGraphDef(&folded_graph_def); input_graph.ToGraphDef(&folded_graph_def);
GraphDef send_recvs_replaced; GraphDef send_recvs_replaced;
TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def, inputs, TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def,
outputs, &send_recvs_replaced)); context.input_names, context.output_names,
TF_RETURN_IF_ERROR(RemoveUnusedNodes(send_recvs_replaced, inputs, outputs, &send_recvs_replaced));
output_graph_def)); TF_RETURN_IF_ERROR(
RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def));
return Status::OK(); return Status::OK();
} }
REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow { namespace tensorflow {
namespace graph_transforms { namespace graph_transforms {
@ -27,15 +28,13 @@ namespace graph_transforms {
// the names of all the nodes that data is fed into, or read out of, when the // the names of all the nodes that data is fed into, or read out of, when the
// graph is actually run. // graph is actually run.
Status FoldConstants(const GraphDef& input_graph_def, Status FoldConstants(const GraphDef& input_graph_def,
const std::vector<string>& inputs, const TransformFuncContext& context,
const std::vector<string>& outputs,
GraphDef* output_graph_def); GraphDef* output_graph_def);
// Analyzes which nodes are used for the given set of inputs and outputs, and // Analyzes which nodes are used for the given set of inputs and outputs, and
// returns a copy of the graph with any that aren't used removed. // returns a copy of the graph with any that aren't used removed.
Status RemoveUnusedNodes(const GraphDef& input_graph_def, Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const std::vector<string>& inputs, const TransformFuncContext& context,
const std::vector<string>& outputs,
GraphDef* output_graph_def); GraphDef* output_graph_def);
} // namespace graph_transforms } // namespace graph_transforms

View File

@ -82,12 +82,13 @@ class ConstantFoldingTest : public ::testing::Test {
TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors)); TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors));
GraphDef folded_graph_def; GraphDef folded_graph_def;
std::vector<string> input_names; graph_transforms::TransformFuncContext context;
for (const std::pair<string, Tensor>& input : inputs) { for (const std::pair<string, Tensor>& input : inputs) {
input_names.push_back(input.first); context.input_names.push_back(input.first);
} }
TF_ASSERT_OK(graph_transforms::FoldConstants(graph_def, input_names, context.output_names = outputs;
outputs, &folded_graph_def)); TF_ASSERT_OK(
graph_transforms::FoldConstants(graph_def, context, &folded_graph_def));
std::unique_ptr<tensorflow::Session> folded_session( std::unique_ptr<tensorflow::Session> folded_session(
tensorflow::NewSession(tensorflow::SessionOptions())); tensorflow::NewSession(tensorflow::SessionOptions()));
@ -187,7 +188,7 @@ class ConstantFoldingTest : public ::testing::Test {
TF_ASSERT_OK(root.ToGraphDef(&graph_def)); TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef result_graph_def; GraphDef result_graph_def;
TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes( TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
graph_def, {"placeholder"}, {"output"}, &result_graph_def)); graph_def, {{"placeholder"}, {"output"}}, &result_graph_def));
std::map<string, const NodeDef*> node_map; std::map<string, const NodeDef*> node_map;
graph_transforms::MapNamesToNodes(result_graph_def, &node_map); graph_transforms::MapNamesToNodes(result_graph_def, &node_map);

View File

@ -1,110 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility that transforms a model with subgraphs that evaluate to constant
// functions into the equivalent model with those subgraphs replaced by Const
// nodes. This simplifies the graph, and makes some further transformations
// easier to perform. It's often useful to run the freeze_graph tool on the
// input graph beforehand to ensure variables have been transformed to Consts.
//
// bazel-bin/tensorflow/tools/graph_transforms/fold_constants_tool \
// --in_graph=graph_def.pb \
// --out_graph=folded_graph_def.pb \
// --inputs=input1,input2 \
// --outputs=output1,output2
//
// Parameters:
// in_graph - name of a file with a frozen GraphDef proto in binary format.
// out_graph - name of the output file to save the folded version to.
// inputs - layer names of the nodes that will be fed data.
// outputs - layer names of the nodes that will be read from after running.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
namespace tensorflow {
namespace {
int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
string in_graph = "";
string out_graph = "";
string inputs_string = "";
string outputs_string = "";
std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"),
Flag("out_graph", &out_graph, "output graph file name"),
Flag("inputs", &inputs_string, "inputs"),
Flag("outputs", &outputs_string, "outputs"),
};
string usage = Flags::Usage(argv[0], flag_list);
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
// We need to call this to set up global state for TensorFlow.
port::InitMain(argv[0], &argc, &argv);
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
if (in_graph.empty()) {
LOG(ERROR) << "in_graph graph can't be empty";
return -1;
}
if (out_graph.empty()) {
LOG(ERROR) << "out_graph graph can't be empty";
return -1;
}
std::vector<string> inputs = str_util::Split(inputs_string, ',');
std::vector<string> outputs = str_util::Split(outputs_string, ',');
GraphDef graph_def;
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
if (!load_status.ok()) {
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
<< load_status.error_message();
return -1;
}
GraphDef folded_graph_def;
Status folding_result = graph_transforms::FoldConstants(
graph_def, inputs, outputs, &folded_graph_def);
if (!folding_result.ok()) {
LOG(ERROR) << "Folding failed " << folding_result.error_message();
return -1;
}
Status save_status =
WriteBinaryProto(Env::Default(), out_graph, folded_graph_def);
if (!save_status.ok()) {
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
<< save_status.error_message();
return -1;
}
return 0;
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
return tensorflow::ParseFlagsAndConvertGraph(argc, argv);
}

View File

@ -0,0 +1,193 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
namespace {
// Ensures the tensor is the expected shape.
Status ErrorIfNotVector(const Tensor& input, const string& input_name,
int expected_width) {
if ((input.shape().dims() != 1) ||
(input.shape().dim_size(0) != expected_width)) {
return errors::InvalidArgument(input_name,
" input to batch norm has bad shape: ",
input.shape().DebugString());
}
return Status::OK();
}
} // namespace
// Finds monolithic batch norm ops (as used in early versions of TensorFlow) and
// converts them into premultiplied weight inputs to convolutions.
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef current_graph_def = input_graph_def;
// We have to do several passes to catch all the old BN nodes, since many of
// them may share inputs and so be excluded from replacement in one pass.
bool did_graph_change;
do {
did_graph_change = false;
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, // clang-format off
{"BatchNormWithGlobalNormalization", // batch_norm_node
{
{"Conv2D", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
}
},
{"Const"}, // mean_node
{"Const"}, // variance_node
{"Const"}, // beta_node
{"Const"}, // gamma_node
}
}, // clang-format on
[&did_graph_change](const NodeMatch& match,
const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
// Find all the nodes we expect in the subgraph.
const NodeDef& batch_norm_node = match.node;
CHECK_EQ("BatchNormWithGlobalNormalization", batch_norm_node.op());
const NodeDef& conv_node = match.inputs[0].node;
CHECK_EQ("Conv2D", conv_node.op());
const NodeDef& input_node = match.inputs[0].inputs[0].node;
const NodeDef& weights_node = match.inputs[0].inputs[1].node;
CHECK_EQ("Const", weights_node.op());
const NodeDef& mean_node = match.inputs[1].node;
CHECK_EQ("Const", mean_node.op());
const NodeDef& variance_node = match.inputs[2].node;
CHECK_EQ("Const", variance_node.op());
const NodeDef& beta_node = match.inputs[3].node;
CHECK_EQ("Const", beta_node.op());
const NodeDef& gamma_node = match.inputs[4].node;
CHECK_EQ("Const", gamma_node.op());
// We have a set of vectors that we want to combine into a vector of
// scale values to apply column-wise to the weight input to the conv,
// and an offset vector that we'll apply to the output of the conv.
Tensor weights = GetNodeTensorAttr(weights_node, "value");
Tensor mean = GetNodeTensorAttr(mean_node, "value");
Tensor variance = GetNodeTensorAttr(variance_node, "value");
Tensor beta = GetNodeTensorAttr(beta_node, "value");
Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
const float variance_epsilon =
batch_norm_node.attr().at("variance_epsilon").f();
const bool scale_after_normalization =
batch_norm_node.attr().at("scale_after_normalization").b();
// Make sure all the inputs really are vectors, with as many entries
// as there are columns in the weights.
const int64 weights_cols = weights.shape().dim_size(3);
TF_RETURN_IF_ERROR(ErrorIfNotVector(mean, "Mean", weights_cols));
TF_RETURN_IF_ERROR(
ErrorIfNotVector(variance, "Variance", weights_cols));
TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", weights_cols));
TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", weights_cols));
// Calculate the scale and offset values to apply.
std::vector<float> scale_values(weights_cols);
std::vector<float> offset_values(weights_cols);
if (scale_after_normalization) {
for (int i = 0; i < weights_cols; ++i) {
scale_values[i] =
(1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
gamma.flat<float>()(i);
offset_values[i] = 0.0f;
}
} else {
for (int i = 0; i < weights_cols; ++i) {
scale_values[i] =
(1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
offset_values[i] = (-mean.flat<float>()(i) * scale_values[i]) +
beta.flat<float>()(i);
}
}
// Multiply the original weights by the scale vector.
auto weights_matrix = weights.flat_inner_dims<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
for (int64 col = 0; col < weights_cols; ++col) {
scaled_weights_matrix(row, col) =
weights_matrix(row, col) * scale_values[col];
}
}
// Figure out the remaining bias to add on.
Tensor bias_offset(DT_FLOAT, {weights_cols});
auto bias_offset_vector = bias_offset.flat<float>();
for (int64 col = 0; col < weights_cols; ++col) {
bias_offset_vector(col) = offset_values[col];
}
// Construct the new nodes.
NodeDef scaled_weights_node;
scaled_weights_node.set_op("Const");
scaled_weights_node.set_name(weights_node.name());
SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
SetNodeTensorAttr<float>("value", scaled_weights,
&scaled_weights_node);
new_nodes->push_back(scaled_weights_node);
// The input and convolution can be copied straight over, since the
// name of the scaled weights constant is the same as the original.
new_nodes->push_back(input_node);
new_nodes->push_back(conv_node);
NodeDef bias_offset_node;
bias_offset_node.set_op("Const");
bias_offset_node.set_name(conv_node.name() + "_bn_offset");
SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
new_nodes->push_back(bias_offset_node);
NodeDef bias_add_node;
bias_add_node.set_op("BiasAdd");
bias_add_node.set_name(batch_norm_node.name());
CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
AddNodeInput(conv_node.name(), &bias_add_node);
AddNodeInput(bias_offset_node.name(), &bias_add_node);
new_nodes->push_back(bias_add_node);
did_graph_change = true;
return Status::OK();
},
{}, &replaced_graph_def));
current_graph_def = replaced_graph_def;
} while (did_graph_change);
*output_graph_def = current_graph_def;
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,128 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class FoldOldBatchNormsTest : public ::testing::Test {
protected:
void TestFoldOldBatchNorms() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
test::FillValues<float>(&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
{1, 1, 1, 1}, "VALID");
Tensor mean_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&mean_data, {10.0f, 20.0f});
Output mean_op =
Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
Tensor variance_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&variance_data, {0.25f, 0.5f});
Output variance_op = Const(root.WithOpName("variance_op"),
Input::Initializer(variance_data));
Tensor beta_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&beta_data, {0.1f, 0.6f});
Output beta_op =
Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
Tensor gamma_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
Output gamma_op =
Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
// This is needed because we're trying to convert over a deprecated op which
// should only be present in older GraphDef files. Without this we see a
// deprecation error.
// This is justified because we're trying to test a tool that is expected to
// run on legacy files, to help users convert over to less problematic
// versions.
NodeDef batch_norm_node;
batch_norm_node.set_op("BatchNormWithGlobalNormalization");
batch_norm_node.set_name("output");
AddNodeInput("conv_op", &batch_norm_node);
AddNodeInput("mean_op", &batch_norm_node);
AddNodeInput("variance_op", &batch_norm_node);
AddNodeInput("beta_op", &batch_norm_node);
AddNodeInput("gamma_op", &batch_norm_node);
SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
*(original_graph_def.mutable_node()->Add()) = batch_norm_node;
original_graph_def.mutable_versions()->set_producer(8);
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef fused_graph_def;
TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
&fused_graph_def));
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
}
}
};
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
TestFoldOldBatchNorms();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,200 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
Status FuseResizePadAndConv(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"Conv2D",
{
{"MirrorPad",
{
{"ResizeBilinear"},
{"*"}
}
},
{"*"}
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
// Find all the nodes we expect in the subgraph.
const NodeDef& conv_node = match.node;
const NodeDef& mirror_pad_node = match.inputs[0].node;
const NodeDef& weights_node = match.inputs[1].node;
const NodeDef& resize_node = match.inputs[0].inputs[0].node;
const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node;
// We'll be reusing the old weights and pad dimensions.
new_nodes->push_back(weights_node);
new_nodes->push_back(pad_dims_node);
// Set up the new fused version of the convolution op.
NodeDef fused_conv;
fused_conv.set_op("FusedResizeAndPadConv2D");
fused_conv.set_name(match.node.name());
AddNodeInput(resize_node.input(0), &fused_conv);
AddNodeInput(resize_node.input(1), &fused_conv);
AddNodeInput(mirror_pad_node.input(1), &fused_conv);
AddNodeInput(conv_node.input(1), &fused_conv);
CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
&fused_conv);
CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv);
CopyNodeAttr(conv_node, "T", "T", &fused_conv);
CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
new_nodes->push_back(fused_conv);
return Status::OK();
},
{}, &replaced_graph_def));
*output_graph_def = replaced_graph_def;
return Status::OK();
}
Status FuseResizeAndConv(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"Conv2D",
{
{"ResizeBilinear"},
{"*"}
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
// Find all the nodes we expect in the subgraph.
const NodeDef& conv_node = match.node;
const NodeDef& resize_node = match.inputs[0].node;
const NodeDef& weights_node = match.inputs[1].node;
// We'll be reusing the old weights.
new_nodes->push_back(weights_node);
// Create a 'no-op' mirror padding node that has no effect.
NodeDef pad_dims_node;
pad_dims_node.set_op("Const");
pad_dims_node.set_name(conv_node.name() + "_dummy_paddings");
SetNodeAttr("dtype", DT_INT32, &pad_dims_node);
SetNodeTensorAttr<int32>("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0},
&pad_dims_node);
new_nodes->push_back(pad_dims_node);
// Set up the new fused version of the convolution op.
NodeDef fused_conv;
fused_conv.set_op("FusedResizeAndPadConv2D");
fused_conv.set_name(match.node.name());
AddNodeInput(resize_node.input(0), &fused_conv);
AddNodeInput(resize_node.input(1), &fused_conv);
AddNodeInput(pad_dims_node.name(), &fused_conv);
AddNodeInput(conv_node.input(1), &fused_conv);
CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
&fused_conv);
SetNodeAttr("mode", "REFLECT", &fused_conv);
CopyNodeAttr(conv_node, "T", "T", &fused_conv);
CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
new_nodes->push_back(fused_conv);
return Status::OK();
},
{}, &replaced_graph_def));
*output_graph_def = replaced_graph_def;
return Status::OK();
}
Status FusePadAndConv(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"Conv2D",
{
{"MirrorPad",
{
{"*"},
{"*"},
}
},
{"*"}
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
// Find all the nodes we expect in the subgraph.
const NodeDef& conv_node = match.node;
CHECK_EQ("Conv2D", conv_node.op());
const NodeDef& mirror_pad_node = match.inputs[0].node;
CHECK_EQ("MirrorPad", mirror_pad_node.op());
const NodeDef& weights_node = match.inputs[1].node;
const NodeDef& input_node = match.inputs[0].inputs[0].node;
const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node;
// We'll be reusing the old weights and pad dimensions.
new_nodes->push_back(weights_node);
new_nodes->push_back(input_node);
new_nodes->push_back(pad_dims_node);
// Set up the new fused version of the convolution op.
NodeDef fused_conv;
fused_conv.set_op("FusedPadConv2D");
fused_conv.set_name(match.node.name());
AddNodeInput(mirror_pad_node.input(0), &fused_conv);
AddNodeInput(mirror_pad_node.input(1), &fused_conv);
AddNodeInput(conv_node.input(1), &fused_conv);
CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv);
CopyNodeAttr(conv_node, "T", "T", &fused_conv);
CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
new_nodes->push_back(fused_conv);
return Status::OK();
},
{}, &replaced_graph_def));
*output_graph_def = replaced_graph_def;
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("fuse_resize_pad_and_conv", FuseResizePadAndConv);
REGISTER_GRAPH_TRANSFORM("fuse_resize_and_conv", FuseResizeAndConv);
REGISTER_GRAPH_TRANSFORM("fuse_pad_and_conv", FusePadAndConv);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,212 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status FuseResizePadAndConv(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
Status FuseResizeAndConv(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
Status FusePadAndConv(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class FuseConvolutionsTest : public ::testing::Test {
protected:
void TestFuseResizePadAndConv() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op,
Const(root.WithOpName("size"), {12, 4}),
ResizeBilinear::AlignCorners(false));
Tensor pad_dims_data(DT_INT32, TensorShape({4, 2}));
test::FillValues<int32>(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0});
Output pad_dims_op = Const(root.WithOpName("pad_dims_op"),
Input::Initializer(pad_dims_data));
Output pad_op =
MirrorPad(root.WithOpName("pad_op"), resize_op, pad_dims_op, "REFLECT");
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
test::FillValues<float>(&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op,
{1, 1, 1, 1}, "VALID");
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef fused_graph_def;
TF_ASSERT_OK(FuseResizePadAndConv(original_graph_def, {{}, {"output"}},
&fused_graph_def));
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("Conv2D", node.op());
EXPECT_NE("MirrorPad", node.op());
EXPECT_NE("ResizeBilinear", node.op());
}
}
void TestFuseResizeAndConv() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op,
Const(root.WithOpName("size"), {12, 4}),
ResizeBilinear::AlignCorners(false));
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
test::FillValues<float>(&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("output"), resize_op, weights_op,
{1, 1, 1, 1}, "VALID");
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef fused_graph_def;
TF_ASSERT_OK(FuseResizeAndConv(original_graph_def, {{}, {"output"}},
&fused_graph_def));
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("Conv2D", node.op());
EXPECT_NE("ResizeBilinear", node.op());
}
}
void TestFusePadAndConv() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Tensor pad_dims_data(DT_INT32, TensorShape({4, 2}));
test::FillValues<int32>(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0});
Output pad_dims_op = Const(root.WithOpName("pad_dims_op"),
Input::Initializer(pad_dims_data));
Output pad_op =
MirrorPad(root.WithOpName("pad_op"), input_op, pad_dims_op, "REFLECT");
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
test::FillValues<float>(&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op,
{1, 1, 1, 1}, "VALID");
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef fused_graph_def;
TF_ASSERT_OK(
FusePadAndConv(original_graph_def, {{}, {"output"}}, &fused_graph_def));
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("Conv2D", node.op());
EXPECT_NE("MirrorPad", node.op());
}
}
};
TEST_F(FuseConvolutionsTest, TestFuseResizePadAndConv) {
TestFuseResizePadAndConv();
}
TEST_F(FuseConvolutionsTest, TestFuseResizeAndConv) { TestFuseResizeAndConv(); }
TEST_F(FuseConvolutionsTest, TestFusePadAndConv) { TestFusePadAndConv(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,104 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Renames all nodes not uses as graph inputs or outputs to short numerical
// forms.
Status ObsfucateNames(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
std::unordered_set<string> required_nodes;
for (const string& input : context.input_names) {
required_nodes.insert(input);
}
for (const string& output : context.output_names) {
required_nodes.insert(output);
}
for (const string& required_node : required_nodes) {
LOG(INFO) << "required_node=" << required_node;
}
const string valid_chars =
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
const int64 chars_size = valid_chars.size();
std::map<string, string> new_names;
int64 name_index = 0;
for (const NodeDef& input_node : input_graph_def.node()) {
const string& old_name = input_node.name();
string new_name;
if (required_nodes.count(old_name)) {
new_name = old_name;
} else {
do {
int64 remaining = name_index;
new_name = "";
while (true) {
const int64 remainder = (remaining % chars_size);
const char current_char = valid_chars[remainder];
new_name = current_char + new_name;
remaining /= chars_size;
if (remaining <= 0) {
break;
}
}
++name_index;
} while (required_nodes.count(new_name));
}
new_names[old_name] = new_name;
}
output_graph_def->Clear();
for (const NodeDef& input_node : input_graph_def.node()) {
NodeDef* node = output_graph_def->mutable_node()->Add();
node->CopyFrom(input_node);
const string& old_name = input_node.name();
node->set_name(new_names[old_name]);
node->mutable_input()->Clear();
for (const string& input_name : input_node.input()) {
string prefix;
string input_node_name;
string suffix;
NodeNamePartsFromInput(input_name, &prefix, &input_node_name, &suffix);
if (new_names.count(input_node_name) == 0) {
return errors::InvalidArgument("No node named ", input_node_name,
" for input to ", old_name);
}
string new_input_name = prefix + new_names[input_node_name] + suffix;
*(node->mutable_input()->Add()) = new_input_name;
}
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,142 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status ObsfucateNames(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class ObsfucateNamesTest : public ::testing::Test {
protected:
void TestSimpleTree() {
GraphDef graph_def;
NodeDef* add_node1 = graph_def.add_node();
add_node1->set_name("add_node1");
add_node1->set_op("Add");
add_node1->add_input("add_node2");
add_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node3");
add_node3->add_input("const_node4");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* const_node4 = graph_def.add_node();
const_node4->set_name("const_node4");
const_node4->set_op("Const");
GraphDef result;
TF_ASSERT_OK(
ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node1"));
EXPECT_EQ(0, node_lookup.count("add_node2"));
EXPECT_EQ(0, node_lookup.count("add_node3"));
EXPECT_EQ(1, node_lookup.count("const_node1"));
EXPECT_EQ(0, node_lookup.count("const_node2"));
EXPECT_EQ(0, node_lookup.count("const_node3"));
EXPECT_EQ(0, node_lookup.count("const_node4"));
}
void TestManyNodes() {
GraphDef graph_def;
for (int i = 0; i < 1000; ++i) {
NodeDef* const_node = graph_def.add_node();
const_node->set_name(strings::StrCat("const_node", i));
const_node->set_op("Const");
}
GraphDef result;
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}},
&result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("const_node0"));
EXPECT_EQ(0, node_lookup.count("const_node500"));
EXPECT_EQ(1, node_lookup.count("const_node999"));
}
void TestNameClashes() {
GraphDef graph_def;
for (int i = 0; i < 1000; ++i) {
NodeDef* const_node = graph_def.add_node();
const_node->set_name(strings::StrCat("1", i));
const_node->set_op("Const");
}
GraphDef result;
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("10"));
EXPECT_EQ(1, node_lookup.count("19"));
std::unordered_set<string> names;
for (const NodeDef& node : result.node()) {
EXPECT_EQ(0, names.count(node.name()))
<< "Found multiple nodes with name '" << node.name() << "'";
names.insert(node.name());
}
}
};
TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); }
TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); }
TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,922 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
// The name of the float op.
string float_name;
// Which attributes to copy directly over.
std::vector<string> attrs_to_copy;
// Extra data type attributes we need to set.
std::vector<std::pair<string, DataType>> dtypes_to_set;
// What depth of inputs the op can read in.
DataType input_bit_depth;
// The depth of the op's quantized outputs.
DataType output_bit_depth;
// Which inputs (e.g. shapes) aren't involved in the quantization process.
std::set<int32> unquantized_inputs;
// How the outputs are arranged, either
// [input0, input1, min0, max0, min1, max1] for contiguous, or
// [input0, input1, min0, min1, max0, max1] for separate.
// The separate order is needed because it's the only way to specify unknown
// numbers of inputs for ops like Concat.
enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};
// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
static const std::vector<QuantizedOpInfo> op_list = {
{"AvgPool",
{"ksize", "strides", "padding"},
{{"T", DT_QUINT8}},
DT_QUINT8,
DT_QUINT8,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"BiasAdd",
{},
{{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
DT_QUINT8,
DT_QINT32,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"Concat",
{"N"},
{{"T", DT_QUINT8}},
DT_QUINT8,
DT_QUINT8,
{0},
QuantizedOpInfo::SEPARATE_MIN_MAX},
{"Conv2D",
{"strides", "padding"},
{{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
DT_QUINT8,
DT_QINT32,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"MatMul",
{"transpose_a", "transpose_b"},
{{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
DT_QUINT8,
DT_QINT32,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"MaxPool",
{"ksize", "strides", "padding"},
{{"T", DT_QUINT8}},
DT_QUINT8,
DT_QUINT8,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"Relu",
{},
{{"Tinput", DT_QUINT8}},
DT_QUINT8,
DT_QUINT8,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"Relu6",
{},
{{"Tinput", DT_QUINT8}},
DT_QUINT8,
DT_QUINT8,
{},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
{"Reshape",
{},
{{"T", DT_QUINT8}},
DT_QUINT8,
DT_QUINT8,
{1},
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
};
return op_list;
}
namespace {
// Replaces invalid characters in input names to get a unique node name.
string UniqueNodeNameFromInput(const string& input_name) {
string prefix;
string node_name;
string suffix;
NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
string result;
if (prefix == "^") {
result += "__hat__";
}
result += node_name;
if (suffix != "") {
result += "__port__" + suffix.substr(1, suffix.size() - 1);
}
return result;
}
// Pulls two float values from the named parameters, with a lot of checking.
Status ExtractRangeFromParams(const TransformFuncContext& context,
const string& min_name, const string& max_name,
float* min_value, float* max_value,
bool* has_range) {
// See if we've been given quantized inputs with a known range.
const bool has_min = (context.params.count(min_name) != 0);
const bool has_max = (context.params.count(max_name) != 0);
*has_range = (has_min || has_max);
if (!*has_range) {
return Status::OK();
}
if (!has_min || !has_max) {
return errors::InvalidArgument("You must pass both ", min_name, " and ",
max_name, " into quantize_nodes");
}
std::vector<string> min_strings = context.params.at(min_name);
std::vector<string> max_strings = context.params.at(max_name);
if ((min_strings.size() != 1) || (max_strings.size() != 1)) {
return errors::InvalidArgument("You must pass a single ", min_name,
" and single ", max_name,
" value into "
"quantize_nodes");
}
if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) {
return errors::InvalidArgument("Couldn't decode ", min_name,
" as a number: ", min_strings[0]);
}
if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) {
return errors::InvalidArgument("Couldn't decode ", max_name,
" as a number: ", max_strings[0]);
}
return Status::OK();
}
bool AreAttrsEqual(const NodeDef* current_node, const NodeDef* other_node) {
if (current_node->attr_size() != other_node->attr_size()) {
return false;
}
string current_serialized;
string other_serialized;
for (const auto& attr : other_node->attr()) {
auto iter = current_node->attr().find(attr.first);
if (iter == current_node->attr().end()) return false;
iter->second.SerializeToString(&current_serialized);
attr.second.SerializeToString(&other_serialized);
if (current_serialized != other_serialized) return false;
}
return true;
}
} // namespace
// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but can
// also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that avoids
// the need for any O(n^2) algorithms when identifying duplicates.
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
// Make sure we can look up inputs and outputs quickly.
std::set<string> input_names(context.input_names.begin(),
context.input_names.end());
std::set<string> output_names(context.output_names.begin(),
context.output_names.end());
GraphDef current_graph_def = input_graph_def;
// Keep running the merging until no more duplicates are found.
bool any_duplicates_found;
do {
any_duplicates_found = false;
// First arrange all of the nodes by a hash of their contents.
std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
for (const NodeDef& node : current_graph_def.node()) {
NodeDef nameless_node = node;
// The name matters if it's being used as an input or output node,
// otherwise ignore it when looking for duplicates.
if (!input_names.count(node.name()) && !output_names.count(node.name())) {
nameless_node.set_name("");
}
const uint64 hash = HashNodeDef(nameless_node);
hashed_nodes[hash].push_back(&node);
}
// If we have multiple nodes with the same hash, then we know they're
// duplicates and can be removed, unless they're stateful.
std::map<string, string> inputs_to_rename;
GraphDef merged_graph_def;
for (const std::pair<uint64, std::vector<const NodeDef*>> hashed_node_info :
hashed_nodes) {
const std::vector<const NodeDef*>& hash_node_list =
hashed_node_info.second;
for (int i = 0; i < hash_node_list.size(); ++i) {
const NodeDef* current_node = hash_node_list[i];
const OpDef* op_def = nullptr;
TF_RETURN_IF_ERROR(
OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
if (is_duplicate) {
const string original_name = hash_node_list[0]->name();
inputs_to_rename[current_node->name() + ":*"] = original_name;
any_duplicates_found = true;
} else {
NodeDef* new_node = merged_graph_def.mutable_node()->Add();
*new_node = *current_node;
}
}
}
// Update the graph so that any nodes that referred to removed inputs now
// pull from the remaining duplicate.
RenameNodeInputs(merged_graph_def, inputs_to_rename, &current_graph_def);
} while (any_duplicates_found);
*output_graph_def = current_graph_def;
return Status::OK();
}
// Looks for the patterns that indicate there are two eight-bit ops feeding into
// each other, separated by a conversion up to float and back again. These occur
// during the initial conversion of ops to their quantized forms. Because we're
// only looking at an individual op in that phase and don't know if its inputs
// and outputs are eight-bit-capable, we start by converting the actual op into
// quantized form, but add float conversions before and after. This pass gets
// rid of those conversions if it turns out we do have adjacent ops capable of
// eight-bit processing.
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
std::set<string> graph_outputs;
for (const string& output_name : context.output_names) {
graph_outputs.insert(NodeNameFromInput(output_name));
}
std::map<string, string> inputs_to_rename;
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"QuantizeV2",
{
{"Dequantize"},
{"Min"},
{"Max"},
}
}, // clang-format on
[&inputs_to_rename, &graph_outputs](const NodeMatch& match,
const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& quantize_node = match.node;
const NodeDef& dequantize_node = match.inputs[0].node;
inputs_to_rename[quantize_node.name() + ":0"] =
dequantize_node.input(0);
inputs_to_rename[quantize_node.name() + ":1"] =
dequantize_node.input(1);
inputs_to_rename[quantize_node.name() + ":2"] =
dequantize_node.input(2);
// Are other sub-graphs using the float intermediate result? If so,
// preserve it, but the input renaming still rewires the eight-bit ops
// so they don't go through float.
if (output_nodes.count(dequantize_node.name()) ||
graph_outputs.count(dequantize_node.name())) {
CopyOriginalMatch(match, new_nodes);
}
return Status::OK();
},
{true}, &replaced_graph_def));
RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def);
return Status::OK();
}
// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
Status QuantizePlaceholders(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
float input_min;
float input_max;
bool has_input_range;
TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
&input_min, &input_max,
&has_input_range));
if (!has_input_range) {
*output_graph_def = input_graph_def;
return Status::OK();
}
std::map<string, string> inputs_to_rename_first_pass;
std::map<string, string> inputs_to_rename_second_pass;
GraphDef placeholder_graph_def;
placeholder_graph_def.Clear();
for (const NodeDef& node : input_graph_def.node()) {
if (node.op() != "Placeholder") {
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(node);
} else {
string namespace_prefix = node.name() + "_eightbit";
NodeDef quantized_placeholder;
quantized_placeholder.CopyFrom(node);
SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
(placeholder_graph_def.mutable_node()->Add())
->CopyFrom(quantized_placeholder);
NodeDef min_node;
min_node.set_op("Const");
min_node.set_name(namespace_prefix + "/min");
SetNodeAttr("dtype", DT_FLOAT, &min_node);
Tensor min_tensor(DT_FLOAT, {});
min_tensor.flat<float>()(0) = input_min;
SetNodeTensorAttr<float>("value", min_tensor, &min_node);
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(min_node);
NodeDef max_node;
max_node.set_op("Const");
max_node.set_name(namespace_prefix + "/max");
SetNodeAttr("dtype", DT_FLOAT, &max_node);
Tensor max_tensor(DT_FLOAT, {});
max_tensor.flat<float>()(0) = input_max;
SetNodeTensorAttr<float>("value", max_tensor, &max_node);
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(max_node);
const string rename_suffix = "__RENAMED_PLACEHOLDER__";
NodeDef dequantize_node;
dequantize_node.set_op("Dequantize");
dequantize_node.set_name(namespace_prefix + "/dequantize");
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
AddNodeInput(node.name() + rename_suffix, &dequantize_node);
AddNodeInput(min_node.name(), &dequantize_node);
AddNodeInput(max_node.name(), &dequantize_node);
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(dequantize_node);
// First make sure that any internal references to the old placeholder
// now point to the dequantize result.
inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
// Then fix up the dequantize op so that it really points to the
// placeholder.
inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
}
}
GraphDef first_pass_graph_def;
RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
&first_pass_graph_def);
RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
output_graph_def);
return Status::OK();
}
// During training, FakeQuantWithMinMaxVars ops capture a good min/max range for
// an activation layer. To use these during inference, this pass converts those
// ops into Requantizes with the trained min/maxes as constant inputs.
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"FakeQuantWithMinMaxVars",
{
{"*"},
{"Const"},
{"Const"},
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& fake_quant_node = match.node;
const NodeDef& original_op_node = match.inputs[0].node;
const NodeDef& fake_quant_min_node = match.inputs[1].node;
const NodeDef& fake_quant_max_node = match.inputs[2].node;
string namespace_prefix = fake_quant_node.name() + "_eightbit";
new_nodes->push_back(original_op_node);
new_nodes->push_back(fake_quant_min_node);
new_nodes->push_back(fake_quant_max_node);
NodeDef quantize_node;
quantize_node.set_op("QuantizeV2");
quantize_node.set_name(namespace_prefix + "/quantize");
SetNodeAttr("T", DT_QINT32, &quantize_node);
SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
AddNodeInput(fake_quant_node.input(0), &quantize_node);
AddNodeInput(fake_quant_min_node.name(), &quantize_node);
AddNodeInput(fake_quant_max_node.name(), &quantize_node);
new_nodes->push_back(quantize_node);
NodeDef requantize_node;
requantize_node.set_op("Requantize");
requantize_node.set_name(namespace_prefix + "/requantize");
SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
AddNodeInput(quantize_node.name() + ":0", &requantize_node);
AddNodeInput(quantize_node.name() + ":1", &requantize_node);
AddNodeInput(quantize_node.name() + ":2", &requantize_node);
AddNodeInput(fake_quant_min_node.name(), &requantize_node);
AddNodeInput(fake_quant_max_node.name(), &requantize_node);
new_nodes->push_back(requantize_node);
// Convert the 8-bit result back into float for the final output.
NodeDef dequantize_node;
dequantize_node.set_op("Dequantize");
dequantize_node.set_name(fake_quant_node.name());
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
new_nodes->push_back(dequantize_node);
return Status::OK();
},
{}, output_graph_def));
return Status::OK();
}
// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, // clang-format off
{"Requantize",
{
{"QuantizeV2",
{
{"Dequantize",
{
{"Requantize",
{
{"*"},
{"*"},
{"*"},
{"RequantizationRange"},
{"RequantizationRange"},
}
},
{"Requantize"},
{"Requantize"},
}
},
{"Const"},
{"Const"},
},
},
{"QuantizeV2"},
{"QuantizeV2"},
{"Const"},
{"Const"},
}
}, // clang-format on
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& fake_requantize_node = match.node;
const NodeDef& original_op_node =
match.inputs[0].inputs[0].inputs[0].inputs[0].node;
const NodeDef& fake_requantize_min_node = match.inputs[3].node;
const NodeDef& fake_requantize_max_node = match.inputs[4].node;
new_nodes->push_back(original_op_node);
new_nodes->push_back(fake_requantize_min_node);
new_nodes->push_back(fake_requantize_max_node);
NodeDef requantize_node;
requantize_node.CopyFrom(fake_requantize_node);
requantize_node.mutable_input()->Clear();
AddNodeInput(original_op_node.name() + ":0", &requantize_node);
AddNodeInput(original_op_node.name() + ":1", &requantize_node);
AddNodeInput(original_op_node.name() + ":2", &requantize_node);
AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
new_nodes->push_back(requantize_node);
return Status::OK();
},
{}, output_graph_def));
return Status::OK();
}
// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd
// op that we want to apply the trained constant conversions to. This pass tries
// to move FakeQuant ops up the input chain, so they're as close as possible to
// the 32-bit conversion, and so can be easily merged into the automatic dynamic
// Requantizes.
Status HoistFakeQuants(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef current_graph_def = input_graph_def;
const int max_depth = 3;
for (int depth = max_depth; depth > 0; --depth) {
OpTypePattern pattern = {"*"};
for (int i = 0; i < depth; ++i) {
pattern = {"*", {pattern}};
}
pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
GraphDef hoisted_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, pattern,
[depth](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& fake_quant_node = match.node;
const NodeDef& fake_quant_min_node = match.inputs[1].node;
const NodeDef& fake_quant_max_node = match.inputs[2].node;
std::vector<NodeDef> linear_nodes;
NodeMatch current_match = match;
for (int i = 0; i <= depth; ++i) {
linear_nodes.push_back(current_match.inputs[0].node);
current_match = current_match.inputs[0];
}
NodeDef new_fake_quant_node;
new_fake_quant_node.CopyFrom(fake_quant_node);
new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
new_fake_quant_node.set_input(
0, linear_nodes[linear_nodes.size() - 2].input(0));
new_nodes->push_back(new_fake_quant_node);
new_nodes->push_back(fake_quant_min_node);
new_nodes->push_back(fake_quant_max_node);
linear_nodes[linear_nodes.size() - 2].set_input(
0, new_fake_quant_node.name());
linear_nodes.front().set_name(fake_quant_node.name());
for (const NodeDef& linear_node : linear_nodes) {
new_nodes->push_back(linear_node);
}
return Status::OK();
},
{}, &hoisted_graph_def));
current_graph_def = hoisted_graph_def;
}
*output_graph_def = current_graph_def;
return Status::OK();
}
// Converts any float ops that have eight-bit equivalents into their quantized
// forms, so that as much calculation as possible is done in the lower-precision
// format.
Status QuantizeNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
// Loop through all of the quantizable op types, and replace any occurrences
// with equivalent sub-graphs with quantized ops at their core. For example
// this one-input operation:
//
// Input(float)
// |
// v
// Operation
// |
// v
// (float)
//
// Will be turned into it's quantized equivalent:
//
// Input(float) ReshapeDims
// +------v v-------------+
// | Reshape
// | |
// | | ReductionDims
// | +-----+ |
// | | +---c---------+
// | v v v v-------+
// | Min Max
// | +----+ |
// v v v--------+
// Quantize
// |
// v
// QuantizedOperation
// | | |
// v v v
// Dequantize
// |
// v
// (float)
//
// This keeps the inputs and outputs visible to the rest of the graph in
// float
// and converts them down to quantized buffers internally for the
// computation.
// The result will end up with a lot of redundant dequantize/quantize pairs
// between adjacent quantized ops, but a later pass removes these where it
// can.
const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
string op_pattern;
bool is_first = true;
std::map<string, QuantizedOpInfo> op_map;
for (const QuantizedOpInfo& op_info : op_list) {
strings::StrAppend(&op_pattern, (is_first ? "" : "|"), op_info.float_name);
op_map.insert({op_info.float_name, op_info});
is_first = false;
}
// If input_min and input max have been passed in, then we convert all float
// Placeholder nodes into quantized versions, with the supplied values as
// their range.
GraphDef placeholder_graph_def;
TF_RETURN_IF_ERROR(
QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));
// If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
// operations like Relu or MaxPool, move them up so that they're as close as
// possible to ops with 32-bit outputs like BiasAdd or Conv2D.
GraphDef hoisted_graph_def;
TF_RETURN_IF_ERROR(
HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));
// Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
// activation layers, into Requantize ops with those ranges instead. This
// makes it easier to replace the dynamic range calculations that are used
// by default.
GraphDef converted_graph_def;
TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
&converted_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));
// If fallback_min and fallback_max are set, then we'll use hardwired ranges
// for all the 32-bit to 8-bit requantizations.
float fallback_min;
float fallback_max;
bool has_fallback_range;
TF_RETURN_IF_ERROR(ExtractRangeFromParams(
context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
&has_fallback_range));
// Replace all occurrences of the current float op with its quantized
// equivalent.
GraphDef quantized_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
converted_graph_def, {op_pattern},
[&op_map, fallback_min, fallback_max, has_fallback_range](
const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& float_node = match.node;
const QuantizedOpInfo& op_info = op_map[float_node.op()];
string namespace_prefix = float_node.name() + "_eightbit";
// Quantize all of the inputs.
std::vector<string> quantized_input_names;
for (int i = 0; i < float_node.input_size(); ++i) {
// Skip any non-float inputs.
if (op_info.unquantized_inputs.count(i)) {
continue;
}
const string& input_name = float_node.input(i);
string unique_input_name =
namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);
// Add some common constants we need for reshaping inputs.
NodeDef reshape_dims;
reshape_dims.set_op("Const");
reshape_dims.set_name(unique_input_name + "/reshape_dims");
SetNodeAttr("dtype", DT_INT32, &reshape_dims);
Tensor reshape_dims_tensor(DT_INT32, {1});
reshape_dims_tensor.flat<int32>()(0) = -1;
SetNodeTensorAttr<int32>("value", reshape_dims_tensor, &reshape_dims);
new_nodes->push_back(reshape_dims);
NodeDef reduction_dims;
reduction_dims.set_op("Const");
reduction_dims.set_name(unique_input_name + "/reduction_dims");
SetNodeAttr("dtype", DT_INT32, &reduction_dims);
Tensor reduction_dims_tensor(DT_INT32, {1});
reduction_dims_tensor.flat<int32>()(0) = 0;
SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
&reduction_dims);
new_nodes->push_back(reduction_dims);
NodeDef reshape_node;
reshape_node.set_op("Reshape");
reshape_node.set_name(unique_input_name + "/reshape");
SetNodeAttr("T", DT_FLOAT, &reshape_node);
AddNodeInput(input_name, &reshape_node);
AddNodeInput(reshape_dims.name(), &reshape_node);
new_nodes->push_back(reshape_node);
NodeDef min_node;
min_node.set_op("Min");
min_node.set_name(unique_input_name + "/min");
SetNodeAttr("T", DT_FLOAT, &min_node);
SetNodeAttr("keep_dims", false, &min_node);
AddNodeInput(reshape_node.name(), &min_node);
AddNodeInput(reduction_dims.name(), &min_node);
new_nodes->push_back(min_node);
NodeDef max_node;
max_node.set_op("Max");
max_node.set_name(unique_input_name + "/max");
SetNodeAttr("T", DT_FLOAT, &max_node);
SetNodeAttr("keep_dims", false, &max_node);
AddNodeInput(reshape_node.name(), &max_node);
AddNodeInput(reduction_dims.name(), &max_node);
new_nodes->push_back(max_node);
NodeDef quantize_node;
quantize_node.set_op("QuantizeV2");
quantize_node.set_name(unique_input_name + "/quantize");
SetNodeAttr("T", DT_QUINT8, &quantize_node);
SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
AddNodeInput(input_name, &quantize_node);
AddNodeInput(min_node.name(), &quantize_node);
AddNodeInput(max_node.name(), &quantize_node);
new_nodes->push_back(quantize_node);
quantized_input_names.push_back(quantize_node.name());
}
// Set up the quantized version of the current op.
NodeDef quantized_main_node;
quantized_main_node.set_op("Quantized" + float_node.op());
quantized_main_node.set_name(float_node.name() + "/eightbit");
for (const string& attr_to_copy : op_info.attrs_to_copy) {
CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
&quantized_main_node);
}
for (const std::pair<string, DataType>& dtype_to_set :
op_info.dtypes_to_set) {
SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
&quantized_main_node);
}
int quantized_input_index = 0;
for (int i = 0; i < float_node.input_size(); ++i) {
if (op_info.unquantized_inputs.count(i)) {
AddNodeInput(float_node.input(i), &quantized_main_node);
} else {
const string& quantized_input_name =
quantized_input_names[quantized_input_index];
AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
++quantized_input_index;
}
}
if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
for (const string& quantized_input_name : quantized_input_names) {
AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
}
} else {
for (const string& quantized_input_name : quantized_input_names) {
AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
}
for (const string& quantized_input_name : quantized_input_names) {
AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
}
}
new_nodes->push_back(quantized_main_node);
string eight_bit_node_name;
if (op_info.output_bit_depth == DT_QINT32) {
// Shrink the range of the output down from 32 bits to 8.
string requantize_min_input;
string requantize_max_input;
if (has_fallback_range) {
// Use constant values for the min/max range if they were given.
NodeDef fallback_min_node;
fallback_min_node.set_op("Const");
fallback_min_node.set_name(quantized_main_node.name() +
"/fallback_min");
SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
Tensor fallback_min_tensor(DT_FLOAT, {});
fallback_min_tensor.flat<float>()(0) = fallback_min;
SetNodeTensorAttr<float>("value", fallback_min_tensor,
&fallback_min_node);
new_nodes->push_back(fallback_min_node);
NodeDef fallback_max_node;
fallback_max_node.set_op("Const");
fallback_max_node.set_name(quantized_main_node.name() +
"/fallback_max");
SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
Tensor fallback_max_tensor(DT_FLOAT, {});
fallback_max_tensor.flat<float>()(0) = fallback_max;
SetNodeTensorAttr<float>("value", fallback_max_tensor,
&fallback_max_node);
new_nodes->push_back(fallback_max_node);
requantize_min_input = fallback_min_node.name();
requantize_max_input = fallback_max_node.name();
} else {
// Otherwise dynamically measure the range each time.
NodeDef requant_range_node;
requant_range_node.set_op("RequantizationRange");
requant_range_node.set_name(quantized_main_node.name() +
"/requant_range");
SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
AddNodeInput(quantized_main_node.name() + ":0",
&requant_range_node);
AddNodeInput(quantized_main_node.name() + ":1",
&requant_range_node);
AddNodeInput(quantized_main_node.name() + ":2",
&requant_range_node);
new_nodes->push_back(requant_range_node);
requantize_min_input = requant_range_node.name() + ":0";
requantize_max_input = requant_range_node.name() + ":1";
}
NodeDef requantize_node;
requantize_node.set_op("Requantize");
requantize_node.set_name(quantized_main_node.name() + "/requantize");
SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
AddNodeInput(requantize_min_input, &requantize_node);
AddNodeInput(requantize_max_input, &requantize_node);
new_nodes->push_back(requantize_node);
eight_bit_node_name = requantize_node.name();
} else {
eight_bit_node_name = quantized_main_node.name();
}
// Convert the 8-bit result back into float for the final output.
NodeDef dequantize_node;
dequantize_node.set_op("Dequantize");
dequantize_node.set_name(float_node.name());
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
new_nodes->push_back(dequantize_node);
return Status::OK();
},
{}, &quantized_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));
// If we've ended up with two Requantize ops in a row (for example if there
// was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
// using the trained range from the second op.
GraphDef merged_graph_def;
TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
&merged_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));
// There can be duplicate quantize nodes if multiple ops pull from a single
// input, which makes it harder to remove redundant ones, so strip them out.
GraphDef deduped_graph_def;
TF_RETURN_IF_ERROR(
MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));
// Look for Dequantizes that immediately go into Quantizes, and remove them
// since the two together cancel each other out. This allows us to keep the
// data flow in eight bit where two adjacent ops are in eight bit, but still
// keep interoperability with float ops.
TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
output_graph_def));
TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);
REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);
} // namespace graph_transforms
} // namespace tensorflow

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,139 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Converts any large float constants into eight-bit equivalents, with a
// Dequantize op so that subsequent nodes can still access the results in a
// float form.
Status QuantizeWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, {"Const"},
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& old_const_node = match.node;
if (!old_const_node.attr().count("dtype")) {
return errors::InvalidArgument("No 'dtype' attribute for Const node ",
old_const_node.name());
}
if (!old_const_node.attr().count("value")) {
return errors::InvalidArgument("No 'value' attribute for Const node ",
old_const_node.name());
}
const DataType old_dtype = old_const_node.attr().at("dtype").type();
Tensor old_tensor;
if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) {
return errors::InvalidArgument("Decoding Tensor failed for node",
old_const_node.name());
}
const size_t num_elements = old_tensor.NumElements();
// If this isn't a float constant, or it's too small, then reuse the
// same node with no changes.
if ((old_dtype != DT_FLOAT) || (num_elements < 16)) {
new_nodes->push_back(old_const_node);
return Status::OK();
}
const float* old_values = old_tensor.flat<float>().data();
float min = std::numeric_limits<float>::max();
float max = std::numeric_limits<float>::min();
for (int i = 0; i < num_elements; ++i) {
const float value = old_values[i];
min = std::min(min, value);
max = std::max(max, value);
}
// min_value == max_value is a tricky case. It can occur for general
// tensors, and of course for scalars. The quantized ops cannot deal
// with this case, so we set max_value to something else.
// It's a tricky question what is the numerically best solution to
// deal with this degeneracy.
// TODO(petewarden): Better use a tolerance than a hard comparison?
if (min == max) {
if (std::abs(min) < 0.000001f) {
max = min + 1.0f;
} else if (min > 0) {
max = 2.0f * min;
} else {
max = min / 2.0f;
}
}
Tensor quantized_tensor(DT_QUINT8, old_tensor.shape());
FloatTensorToQuantizedInPlace<quint8>(old_tensor, min, max,
&quantized_tensor);
NodeDef quantized_const_node;
quantized_const_node.set_op("Const");
quantized_const_node.set_name(old_const_node.name() +
"_quantized_const");
SetNodeAttr("dtype", DT_QUINT8, &quantized_const_node);
SetNodeTensorAttr<float>("value", quantized_tensor,
&quantized_const_node);
new_nodes->push_back(quantized_const_node);
NodeDef min_node;
min_node.set_op("Const");
min_node.set_name(old_const_node.name() + "_quantized_min");
SetNodeAttr("dtype", DT_FLOAT, &min_node);
Tensor min_tensor(DT_FLOAT, {});
min_tensor.scalar<float>()() = min;
SetNodeTensorAttr<float>("value", min_tensor, &min_node);
new_nodes->push_back(min_node);
NodeDef max_node;
max_node.set_op("Const");
max_node.set_name(old_const_node.name() + "_quantized_max");
SetNodeAttr("dtype", DT_FLOAT, &max_node);
Tensor max_tensor(DT_FLOAT, {});
max_tensor.scalar<float>()() = max;
SetNodeTensorAttr<float>("value", max_tensor, &max_node);
new_nodes->push_back(max_node);
NodeDef dequantize_node;
dequantize_node.set_op("Dequantize");
dequantize_node.set_name(old_const_node.name());
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
AddNodeInput(quantized_const_node.name(), &dequantize_node);
AddNodeInput(min_node.name(), &dequantize_node);
AddNodeInput(max_node.name(), &dequantize_node);
new_nodes->push_back(dequantize_node);
return Status::OK();
},
{}, output_graph_def));
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("quantize_weights", QuantizeWeights);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,103 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status QuantizeWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class QuantizeWeightsTest : public ::testing::Test {
protected:
void TestQuantizeWeights() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10}));
test::FillValues<float>(
&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op,
{1, 1, 1, 1}, "VALID");
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef quantized_graph_def;
TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}},
&quantized_graph_def));
std::unique_ptr<Session> quantized_session(NewSession(SessionOptions()));
TF_ASSERT_OK(quantized_session->Create(quantized_graph_def));
std::vector<Tensor> quantized_outputs;
TF_ASSERT_OK(
quantized_session->Run({}, {"output"}, {}, &quantized_outputs));
test::ExpectTensorNear<float>(original_outputs[0], quantized_outputs[0],
0.5);
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(quantized_graph_def, &node_lookup);
EXPECT_EQ(1, node_lookup.count("input_op"));
const NodeDef* q_input_op = node_lookup.at("input_op");
EXPECT_EQ(DT_FLOAT, q_input_op->attr().at("dtype").type());
EXPECT_EQ(1, node_lookup.count("weights_op"));
const NodeDef* q_weights_op = node_lookup.at("weights_op");
EXPECT_EQ("Dequantize", q_weights_op->op());
const string& weights_const_name =
NodeNameFromInput(q_weights_op->input(0));
EXPECT_EQ(1, node_lookup.count(weights_const_name));
const NodeDef* q_weights_const = node_lookup.at(weights_const_name);
EXPECT_EQ("Const", q_weights_const->op());
EXPECT_EQ(DT_QUINT8, q_weights_const->attr().at("dtype").type());
}
};
TEST_F(QuantizeWeightsTest, TestQuantizeWeights) { TestQuantizeWeights(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,70 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Deletes a given attribute from the specified nodes.
Status RemoveAttribute(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
if (!context.params.count("attribute_name") ||
(context.params.at("attribute_name").size() != 1)) {
return errors::InvalidArgument(
"remove_nodes expects exactly one 'attribute_name' "
"argument, e.g. remove_attribute(op_name=Mul, attribute_name=foo)");
}
string op_name;
if (context.params.count("op_name")) {
if (context.params.at("op_name").size() != 1) {
return errors::InvalidArgument(
"remove_nodes expects a single op_name argument, but found ",
context.params.at("op_name").size());
}
op_name = context.params.at("op_name")[0];
} else {
op_name = "*";
}
const string attribute_name = context.params.at("attribute_name")[0];
output_graph_def->Clear();
for (const NodeDef& node : input_graph_def.node()) {
NodeDef* new_node = output_graph_def->mutable_node()->Add();
new_node->CopyFrom(node);
if (((op_name == "*") || (op_name == node.op())) &&
(node.attr().count(attribute_name))) {
new_node->mutable_attr()->erase(attribute_name);
}
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("remove_attribute", RemoveAttribute);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,123 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status RemoveAttribute(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class RemoveAttributeTest : public ::testing::Test {
protected:
void TestRemoveAttribute() {
GraphDef graph_def;
NodeDef* mul_node1 = graph_def.add_node();
mul_node1->set_name("mul_node1");
mul_node1->set_op("Mul");
mul_node1->add_input("add_node2");
mul_node1->add_input("add_node3");
SetNodeAttr<int32>("foo", 23, mul_node1);
SetNodeAttr<string>("bar", "something", mul_node1);
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
SetNodeAttr<int32>("foo", 46, add_node2);
SetNodeAttr<int32>("bob", 23, add_node2);
SetNodeAttr<string>("bar", "something else", add_node2);
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node1");
add_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* add_node4 = graph_def.add_node();
add_node4->set_name("add_node4");
add_node4->set_op("Add");
add_node4->add_input("add_node2");
add_node4->add_input("add_node3");
GraphDef wildcard_result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"mul_node1"};
context.params.insert(
std::pair<string, std::vector<string>>({"op_name", {string("*")}}));
context.params.insert(std::pair<string, std::vector<string>>(
{"attribute_name", {string("foo")}}));
TF_ASSERT_OK(RemoveAttribute(graph_def, context, &wildcard_result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(wildcard_result, &node_lookup);
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
GraphDef targeted_result;
TransformFuncContext targeted_context;
targeted_context.input_names = {};
targeted_context.output_names = {"mul_node1"};
targeted_context.params.insert(
std::pair<string, std::vector<string>>({"op_name", {string("Mul")}}));
targeted_context.params.insert(std::pair<string, std::vector<string>>(
{"attribute_name", {string("foo")}}));
TF_ASSERT_OK(
RemoveAttribute(graph_def, targeted_context, &targeted_result));
MapNamesToNodes(targeted_result, &node_lookup);
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
}
};
TEST_F(RemoveAttributeTest, TestRemoveAttribute) { TestRemoveAttribute(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,47 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Clears the device field of all ops in the graph.
Status RemoveDevice(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
output_graph_def->Clear();
for (const NodeDef& node : input_graph_def.node()) {
NodeDef* new_node = output_graph_def->mutable_node()->Add();
new_node->CopyFrom(node);
new_node->set_device("");
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("remove_device", RemoveDevice);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,95 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status RemoveDevice(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class RemoveDeviceTest : public ::testing::Test {
protected:
void TestRemoveDevice() {
GraphDef graph_def;
NodeDef* mul_node1 = graph_def.add_node();
mul_node1->set_name("mul_node1");
mul_node1->set_op("Mul");
mul_node1->set_device("//cpu:0");
mul_node1->add_input("add_node2");
mul_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
add_node2->set_device("//gpu:1");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node1");
add_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* add_node4 = graph_def.add_node();
add_node4->set_name("add_node4");
add_node4->set_op("Add");
add_node4->add_input("add_node2");
add_node4->add_input("add_node3");
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"mul_node1"};
TF_ASSERT_OK(RemoveDevice(graph_def, context, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ("", node_lookup.at("mul_node1")->device());
EXPECT_EQ("", node_lookup.at("add_node2")->device());
}
};
TEST_F(RemoveDeviceTest, TestRemoveDevice) { TestRemoveDevice(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,94 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Deletes any specified types of nodes, unless they're necessary for the
// graph's inputs or outputs.
Status RemoveNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
if (!context.params.count("op")) {
return errors::InvalidArgument(
"remove_nodes expects at least one 'op'"
"argument, e.g. remove_nodes(op=Identity)");
}
// Make sure we don't get rid of any nodes used as graph inputs or outputs.
std::set<string> required_nodes;
for (const string& input : context.input_names) {
required_nodes.insert(NodeNameFromInput(input));
}
for (const string& output : context.output_names) {
required_nodes.insert(NodeNameFromInput(output));
}
std::vector<string> ops_to_remove = context.params.at("op");
GraphDef current_graph_def = input_graph_def;
for (const string& op : ops_to_remove) {
// Keep looking for nodes to remove until there are no more changes.
bool any_nodes_removed;
do {
any_nodes_removed = false;
std::map<string, string> inputs_to_rename;
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, {op, {{"*"}}},
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& replace_node = match.node;
// If this node is needed in the inputs or outputs don't replace it.
if (required_nodes.count(replace_node.name())) {
LOG(INFO) << "Skipping replacement for " << replace_node.name();
CopyOriginalMatch(match, new_nodes);
return Status::OK();
}
const NodeDef& input_node = match.inputs[0].node;
inputs_to_rename[replace_node.name()] = input_node.name();
inputs_to_rename["^" + replace_node.name()] =
"^" + input_node.name();
new_nodes->push_back(input_node);
any_nodes_removed = true;
return Status::OK();
},
{true}, &replaced_graph_def));
// Make sure all references to removed nodes now point to their inputs.
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
&current_graph_def);
} while (any_nodes_removed);
}
*output_graph_def = current_graph_def;
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("remove_nodes", RemoveNodes);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,222 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status RemoveNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class RemoveNodesTest : public ::testing::Test {
protected:
void TestRemoveNodes() {
GraphDef graph_def;
NodeDef* add_node1 = graph_def.add_node();
add_node1->set_name("add_node1");
add_node1->set_op("Add");
add_node1->add_input("add_node2");
add_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("identity_node1");
add_node2->add_input("identity_node2");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("identity_node1");
add_node3->add_input("const_node3");
NodeDef* identity_node1 = graph_def.add_node();
identity_node1->set_name("identity_node1");
identity_node1->set_op("Identity");
identity_node1->add_input("const_node1");
NodeDef* identity_node2 = graph_def.add_node();
identity_node2->set_name("identity_node2");
identity_node2->set_op("Identity");
identity_node2->add_input("const_node2");
NodeDef* identity_node3 = graph_def.add_node();
identity_node3->set_name("identity_node3");
identity_node3->set_op("Identity");
identity_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* add_node4 = graph_def.add_node();
add_node4->set_name("add_node4");
add_node4->set_op("Add");
add_node4->add_input("add_node2");
add_node4->add_input("add_node3");
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"add_node1"};
context.params.insert(
std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node1"));
EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0));
EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1));
EXPECT_EQ(1, node_lookup.count("add_node2"));
EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0));
EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1));
EXPECT_EQ(1, node_lookup.count("add_node3"));
EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0));
EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1));
EXPECT_EQ(1, node_lookup.count("add_node4"));
EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0));
EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1));
EXPECT_EQ(0, node_lookup.count("identity_node1"));
EXPECT_EQ(0, node_lookup.count("identity_node2"));
EXPECT_EQ(0, node_lookup.count("identity_node3"));
EXPECT_EQ(1, node_lookup.count("const_node1"));
EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
EXPECT_EQ(1, node_lookup.count("const_node2"));
EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
EXPECT_EQ(1, node_lookup.count("const_node3"));
EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
}
void TestRemoveOutputNodes() {
GraphDef graph_def;
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("const_node1");
add_node->add_input("const_node2");
NodeDef* identity_node = graph_def.add_node();
identity_node->set_name("identity_node");
identity_node->set_op("Identity");
identity_node->add_input("add_node");
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"identity_node"};
context.params.insert(
std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
EXPECT_EQ(1, node_lookup.count("identity_node"));
EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0));
}
void TestRemoveChainedNodes() {
GraphDef graph_def;
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* identity_node1 = graph_def.add_node();
identity_node1->set_name("identity_node1");
identity_node1->set_op("Identity");
identity_node1->add_input("const_node1");
NodeDef* identity_node2 = graph_def.add_node();
identity_node2->set_name("identity_node2");
identity_node2->set_op("Identity");
identity_node2->add_input("identity_node1");
NodeDef* identity_node3 = graph_def.add_node();
identity_node3->set_name("identity_node3");
identity_node3->set_op("Identity");
identity_node3->add_input("identity_node2");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("identity_node3");
add_node->add_input("const_node2");
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"identity_node"};
context.params.insert(
std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
EXPECT_EQ(0, node_lookup.count("identity_node1"));
EXPECT_EQ(0, node_lookup.count("identity_node2"));
EXPECT_EQ(0, node_lookup.count("identity_node3"));
}
};
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,70 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
Status RenameAttribute(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
if (!context.params.count("old_attribute_name") ||
(context.params.at("old_attribute_name").size() != 1) ||
!context.params.count("new_attribute_name") ||
(context.params.at("new_attribute_name").size() != 1)) {
return errors::InvalidArgument(
"remove_nodes expects exactly one 'old_attribute_name' and one "
"'new_attribute_name' argument, e.g. "
"remove_attribute(old_attribute_name=foo, new_attribute_name=bar)");
}
string op_name;
if (context.params.count("op_name")) {
op_name = context.params.at("op_name")[0];
} else {
op_name = "*";
}
const string old_attribute_name = context.params.at("old_attribute_name")[0];
const string new_attribute_name = context.params.at("new_attribute_name")[0];
output_graph_def->Clear();
for (const NodeDef& node : input_graph_def.node()) {
NodeDef* new_node = output_graph_def->mutable_node()->Add();
new_node->CopyFrom(node);
if (((op_name == "*") || (op_name == node.op())) &&
(node.attr().count(old_attribute_name))) {
AttrValue attribute_value = node.attr().at(old_attribute_name);
new_node->mutable_attr()->erase(old_attribute_name);
new_node->mutable_attr()->insert({new_attribute_name, attribute_value});
}
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("rename_attribute", RenameAttribute);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,131 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status RenameAttribute(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class RenameAttributeTest : public ::testing::Test {
protected:
void TestRenameAttribute() {
GraphDef graph_def;
NodeDef* mul_node1 = graph_def.add_node();
mul_node1->set_name("mul_node1");
mul_node1->set_op("Mul");
mul_node1->add_input("add_node2");
mul_node1->add_input("add_node3");
AddNodeAttr<int32>("foo", 23, mul_node1);
AddNodeAttr<string>("bar", "something", mul_node1);
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
AddNodeAttr<int32>("foo", 46, add_node2);
AddNodeAttr<int32>("bob", 23, add_node2);
AddNodeAttr<string>("bar", "something else", add_node2);
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node1");
add_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* add_node4 = graph_def.add_node();
add_node4->set_name("add_node4");
add_node4->set_op("Add");
add_node4->add_input("add_node2");
add_node4->add_input("add_node3");
GraphDef wildcard_result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"mul_node1"};
context.params.insert(
std::pair<string, std::vector<string>>({"op_name", {string("*")}}));
context.params.insert(std::pair<string, std::vector<string>>(
{"old_attribute_name", {string("foo")}}));
context.params.insert(std::pair<string, std::vector<string>>(
{"new_attribute_name", {string("baz")}}));
TF_ASSERT_OK(RenameAttribute(graph_def, context, &wildcard_result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(wildcard_result, &node_lookup);
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("baz"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
GraphDef targeted_result;
TransformFuncContext targeted_context;
targeted_context.input_names = {};
targeted_context.output_names = {"mul_node1"};
targeted_context.params.insert(
std::pair<string, std::vector<string>>({"op_name", {string("Mul")}}));
targeted_context.params.insert(std::pair<string, std::vector<string>>(
{"old_attribute_name", {string("foo")}}));
targeted_context.params.insert(std::pair<string, std::vector<string>>(
{"new_attribute_name", {string("baz")}}));
TF_ASSERT_OK(
RenameAttribute(graph_def, targeted_context, &targeted_result));
MapNamesToNodes(targeted_result, &node_lookup);
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo"));
EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("baz"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
}
};
TEST_F(RenameAttributeTest, TestRenameAttribute) { TestRenameAttribute(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,60 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Changes the op type of a specified op.
Status RenameOp(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
if (!context.params.count("old_op_name") ||
(context.params.at("old_op_name").size() != 1) ||
!context.params.count("new_op_name") ||
(context.params.at("new_op_name").size() != 1)) {
return errors::InvalidArgument(
"remove_nodes expects exactly one 'old_op_name' and 'new_op_name' "
"argument, e.g. rename_op(old_op_name=Mul, new_op_name=Multiply)");
}
const string old_op_name = context.params.at("old_op_name")[0];
const string new_op_name = context.params.at("new_op_name")[0];
output_graph_def->Clear();
for (const NodeDef& node : input_graph_def.node()) {
NodeDef* new_node = output_graph_def->mutable_node()->Add();
new_node->CopyFrom(node);
if (node.op() == old_op_name) {
new_node->set_op(new_op_name);
}
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("rename_op", RenameOp);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,109 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status RenameOp(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class RenameOpTest : public ::testing::Test {
protected:
void TestRenameOp() {
GraphDef graph_def;
NodeDef* mul_node1 = graph_def.add_node();
mul_node1->set_name("mul_node1");
mul_node1->set_op("Mul");
mul_node1->add_input("add_node2");
mul_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node1");
add_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* add_node4 = graph_def.add_node();
add_node4->set_name("add_node4");
add_node4->set_op("Add");
add_node4->add_input("add_node2");
add_node4->add_input("add_node3");
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"mul_node1"};
context.params.insert(std::pair<string, std::vector<string>>(
{"old_op_name", {string("Mul")}}));
context.params.insert(std::pair<string, std::vector<string>>(
{"new_op_name", {string("Multiply")}}));
TF_ASSERT_OK(RenameOp(graph_def, context, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("mul_node1"));
EXPECT_EQ("Multiply", node_lookup.at("mul_node1")->op());
EXPECT_EQ(1, node_lookup.count("add_node2"));
EXPECT_EQ("Add", node_lookup.at("add_node2")->op());
EXPECT_EQ(1, node_lookup.count("add_node3"));
EXPECT_EQ("Add", node_lookup.at("add_node3")->op());
EXPECT_EQ(1, node_lookup.count("add_node4"));
EXPECT_EQ("Add", node_lookup.at("add_node4")->op());
EXPECT_EQ(1, node_lookup.count("const_node1"));
EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
EXPECT_EQ(1, node_lookup.count("const_node2"));
EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
EXPECT_EQ(1, node_lookup.count("const_node3"));
EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
}
};
TEST_F(RenameOpTest, TestRenameOp) { TestRenameOp(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,123 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Rounds any large float constants to the specified number of levels.
Status RoundWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
string num_steps_string;
TF_RETURN_IF_ERROR(
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
int32 num_steps;
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
return errors::InvalidArgument(
"Couldn't interpret the num_steps argument to round_weights as a "
"number:",
num_steps_string);
}
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, {"Const"},
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& old_const_node = match.node;
if (!old_const_node.attr().count("dtype")) {
return errors::InvalidArgument("No 'dtype' attribute for Const node ",
old_const_node.name());
}
if (!old_const_node.attr().count("value")) {
return errors::InvalidArgument("No 'value' attribute for Const node ",
old_const_node.name());
}
const DataType old_dtype = old_const_node.attr().at("dtype").type();
Tensor old_tensor;
if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) {
return errors::InvalidArgument("Decoding Tensor failed for node",
old_const_node.name());
}
const size_t num_elements = old_tensor.NumElements();
// If this isn't a float constant, or it's too small, then reuse the
// same node with no changes. The size is important because small
// constants tend to be used for more accuracy-sensitive calculations,
// and the benefit of shrinking them is very marginal.
if ((old_dtype != DT_FLOAT) || (num_elements < 16)) {
new_nodes->push_back(old_const_node);
return Status::OK();
}
const float* old_values = old_tensor.flat<float>().data();
float min = std::numeric_limits<float>::max();
float max = std::numeric_limits<float>::min();
for (int i = 0; i < num_elements; ++i) {
const float value = old_values[i];
min = std::min(min, value);
max = std::max(max, value);
}
// min_value == max_value is a tricky case. It can occur for general
// tensors, and of course for scalars. The quantized ops cannot deal
// with this case, so we set max_value to something else.
// It's a tricky question what is the numerically best solution to
// deal with this degeneracy.
// TODO(petewarden): Better use a tolerance than a hard comparison?
if (min == max) {
if (std::abs(min) < 0.000001f) {
max = min + 1.0f;
} else if (min > 0) {
max = 2.0f * min;
} else {
min = 2.0f * max;
}
}
Tensor rounded_tensor(DT_FLOAT, old_tensor.shape());
float* rounded_values = rounded_tensor.flat<float>().data();
const float bucket_width = (max - min) / num_steps;
for (int i = 0; i < num_elements; ++i) {
const int32 bucket = std::floor((old_values[i] - min) / bucket_width);
rounded_values[i] = min + (bucket_width * (bucket + 0.5f));
}
NodeDef rounded_const_node;
rounded_const_node.set_op("Const");
rounded_const_node.set_name(old_const_node.name());
SetNodeAttr("dtype", DT_FLOAT, &rounded_const_node);
SetNodeTensorAttr<float>("value", rounded_tensor, &rounded_const_node);
new_nodes->push_back(rounded_const_node);
return Status::OK();
},
{}, output_graph_def));
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("round_weights", RoundWeights);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,96 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status RoundWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class RoundWeightsTest : public ::testing::Test {
protected:
void TestRoundWeights() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
-5.0f, -3.0f, -6.0f});
Output input_op =
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10}));
test::FillValues<float>(
&weights_data,
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
Output weights_op =
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op,
{1, 1, 1, 1}, "VALID");
GraphDef original_graph_def;
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
TF_ASSERT_OK(original_session->Create(original_graph_def));
std::vector<Tensor> original_outputs;
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
GraphDef rounded_graph_def;
TF_ASSERT_OK(
RoundWeights(original_graph_def, {{}, {"output"}}, &rounded_graph_def));
std::unique_ptr<Session> rounded_session(NewSession(SessionOptions()));
TF_ASSERT_OK(rounded_session->Create(rounded_graph_def));
std::vector<Tensor> rounded_outputs;
TF_ASSERT_OK(rounded_session->Run({}, {"output"}, {}, &rounded_outputs));
test::ExpectTensorNear<float>(original_outputs[0], rounded_outputs[0], 0.5);
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(rounded_graph_def, &node_lookup);
EXPECT_EQ(1, node_lookup.count("input_op"));
const NodeDef* r_input_op = node_lookup.at("input_op");
EXPECT_EQ(DT_FLOAT, r_input_op->attr().at("dtype").type());
EXPECT_EQ(1, node_lookup.count("weights_op"));
const NodeDef* r_weights_op = node_lookup.at("weights_op");
EXPECT_EQ("Const", r_weights_op->op());
EXPECT_EQ(DT_FLOAT, r_weights_op->attr().at("dtype").type());
}
};
TEST_F(RoundWeightsTest, TestRoundWeights) { TestRoundWeights(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,43 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// This is a thin wrapper with the standard TransformFunc interface to the
// underlying utility function. The only difference is that we don't use the
// input or output name arguments.
Status SortByExecutionOrderWithUnusedContext(
const GraphDef& input_graph_def, const TransformFuncContext& unused_context,
GraphDef* output_graph_def) {
return SortByExecutionOrder(input_graph_def, output_graph_def);
}
REGISTER_GRAPH_TRANSFORM("sort_by_execution_order",
SortByExecutionOrderWithUnusedContext);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,206 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
class SortByExecutionOrderTest : public ::testing::Test {
protected:
void GetOrder(const GraphDef& graph_def, std::map<string, int>* order) {
for (int i = 0; i < graph_def.node_size(); ++i) {
const NodeDef& node = graph_def.node(i);
(*order)[node.name()] = i;
}
}
void TestSimpleAdd() {
GraphDef graph_def;
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("a_node");
add_node->add_input("b_node");
NodeDef* b_node = graph_def.add_node();
b_node->set_name("b_node");
b_node->set_op("Const");
NodeDef* a_node = graph_def.add_node();
a_node->set_name("a_node");
a_node->set_op("Const");
GraphDef result;
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
std::map<string, int> order;
GetOrder(result, &order);
EXPECT_EQ(2, order["add_node"]);
EXPECT_GT(2, order["a_node"]);
EXPECT_GT(2, order["b_node"]);
}
void TestSimpleLinear() {
GraphDef graph_def;
NodeDef* negative_node = graph_def.add_node();
negative_node->set_name("negative_node");
negative_node->set_op("Negative");
negative_node->add_input("sqrt_node");
NodeDef* relu_node = graph_def.add_node();
relu_node->set_name("relu_node");
relu_node->set_op("Relu");
relu_node->add_input("const_node");
NodeDef* sqrt_node = graph_def.add_node();
sqrt_node->set_name("sqrt_node");
sqrt_node->set_op("Sqrt");
sqrt_node->add_input("relu_node");
NodeDef* const_node = graph_def.add_node();
const_node->set_name("const_node");
const_node->set_op("Const");
GraphDef result;
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
std::map<string, int> order;
GetOrder(result, &order);
EXPECT_EQ(3, order["negative_node"]);
EXPECT_EQ(2, order["sqrt_node"]);
EXPECT_EQ(1, order["relu_node"]);
EXPECT_EQ(0, order["const_node"]);
}
void TestSimpleTree() {
GraphDef graph_def;
NodeDef* add_node1 = graph_def.add_node();
add_node1->set_name("add_node1");
add_node1->set_op("Add");
add_node1->add_input("add_node2");
add_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node3");
add_node3->add_input("const_node4");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* const_node4 = graph_def.add_node();
const_node4->set_name("const_node4");
const_node4->set_op("Const");
GraphDef result;
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
std::map<string, int> order;
GetOrder(result, &order);
EXPECT_EQ(6, order["add_node1"]);
EXPECT_GT(6, order["add_node2"]);
EXPECT_GT(6, order["add_node3"]);
EXPECT_GT(5, order["const_node1"]);
EXPECT_GT(5, order["const_node2"]);
EXPECT_GT(5, order["const_node3"]);
EXPECT_GT(5, order["const_node4"]);
}
void TestCommonAncestor() {
GraphDef graph_def;
NodeDef* add_node1 = graph_def.add_node();
add_node1->set_name("add_node1");
add_node1->set_op("Add");
add_node1->add_input("add_node2");
add_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node1");
add_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
GraphDef result;
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
std::map<string, int> order;
GetOrder(result, &order);
EXPECT_EQ(5, order["add_node1"]);
EXPECT_GT(5, order["add_node2"]);
EXPECT_GT(5, order["add_node3"]);
EXPECT_GT(4, order["const_node2"]);
EXPECT_GT(4, order["const_node3"]);
EXPECT_GT(3, order["const_node1"]);
}
};
TEST_F(SortByExecutionOrderTest, TestSimpleAdd) { TestSimpleAdd(); }
TEST_F(SortByExecutionOrderTest, TestSimpleLinear) { TestSimpleLinear(); }
TEST_F(SortByExecutionOrderTest, TestSimpleTree) { TestSimpleTree(); }
TEST_F(SortByExecutionOrderTest, TestCommonAncestor) { TestCommonAncestor(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,215 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
namespace {
Status TypeForPlaceholder(const TransformFuncContext& context,
const string& node_name, DataType* result) {
// If we don't find anything else, return float.
*result = DT_FLOAT;
// Check to see if we have been given a default for all placeholders.
if (context.params.count("type")) {
if (context.params.at("type").size() != 1) {
return errors::InvalidArgument(
"You must pass no more than one default 'type' to "
"strip_unused_nodes");
}
const string& type_string = context.params.at("type")[0];
if (!DataTypeFromString(type_string, result)) {
return errors::InvalidArgument("Couldn't understand type argument '",
type_string, "'");
}
}
// See if there's a particular type specified for this placeholder.
if (context.params.count("name") || context.params.count("type_for_name")) {
if (!context.params.count("name") ||
!context.params.count("type_for_name") ||
(context.params.at("type_for_name").size() !=
context.params.at("name").size())) {
return errors::InvalidArgument(
"You must pass a 'type_for_name' arg for every 'name', e.g. "
"strip_unused_nodes(name=foo, type_for_name=float, name=bar, "
"type_for_name=quint8");
}
const int name_count = context.params.at("name").size();
for (int i = 0; i < name_count; ++i) {
if (context.params.at("name")[i] == node_name) {
const string& type_string = context.params.at("type_for_name")[i];
if (!DataTypeFromString(type_string, result)) {
return errors::InvalidArgument("Couldn't understand type argument '",
type_string, "'");
}
}
}
}
return Status::OK();
}
// Takes a comma-separated string of numbers and parses them into a shape.
bool TensorShapeFromString(const string& shape_string, TensorShape* result) {
if (shape_string == "") {
return false;
}
std::vector<int64> dims;
if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
return false;
}
*result = TensorShape(dims);
return true;
}
Status ShapeForPlaceholder(const TransformFuncContext& context,
const string& node_name, TensorShape* result) {
// If we don't find anything else, return scalar.
*result = {};
// Check to see if we have been given a default for all placeholders.
if (context.params.count("type")) {
if (context.params.at("shape").size() != 1) {
return errors::InvalidArgument(
"You must pass no more than one default 'shape' to "
"strip_unused_nodes");
}
const string& shape_string = context.params.at("shape")[0];
if (!TensorShapeFromString(shape_string, result)) {
return errors::InvalidArgument("Couldn't understand shape argument '",
shape_string, "'");
}
}
// See if there's a particular type specified for this placeholder.
if (context.params.count("name") || context.params.count("type_for_name")) {
if (!context.params.count("name") ||
!context.params.count("type_for_name") ||
(context.params.at("type_for_name").size() !=
context.params.at("name").size())) {
return errors::InvalidArgument(
"You must pass a 'shape_for_name' arg for every 'name', e.g. "
"strip_unused_nodes(name=foo, shape_for_name=\"2,2,1\", name=bar, "
"shape_for_name=\"1\"");
}
const int name_count = context.params.at("name").size();
for (int i = 0; i < name_count; ++i) {
if (context.params.at("name")[i] == node_name) {
const string& shape_string = context.params.at("shape_for_name")[i];
if (!TensorShapeFromString(shape_string, result)) {
return errors::InvalidArgument("Couldn't understand shape argument '",
shape_string, "'");
}
}
}
}
return Status::OK();
}
} // namespace
// Delete any nodes that don't contribute to the inference result.
Status StripUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
std::set<string> required_nodes;
std::set<string> input_nodes;
for (const string& input : context.input_names) {
required_nodes.insert(NodeNameFromInput(input));
input_nodes.insert(NodeNameFromInput(input));
}
for (const string& output : context.output_names) {
required_nodes.insert(output);
}
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(input_graph_def, &node_lookup);
std::vector<string> current_inputs;
for (const string& output_name : context.output_names) {
current_inputs.push_back(NodeNameFromInput(output_name));
}
while (!current_inputs.empty()) {
std::set<string> next_inputs;
for (const string& current_input : current_inputs) {
required_nodes.insert(current_input);
if (input_nodes.count(current_input)) {
continue;
}
if (!node_lookup.count(current_input)) {
return errors::InvalidArgument("Input node ", current_input,
" not found in graph");
}
const NodeDef* current_node = node_lookup[current_input];
for (const string& input_name : current_node->input()) {
string input_node_name = NodeNameFromInput(input_name);
if (!required_nodes.count(input_node_name)) {
next_inputs.insert(input_node_name);
}
}
}
current_inputs =
std::vector<string>(next_inputs.begin(), next_inputs.end());
}
GraphDef filtered_graph_def;
FilterGraphDef(input_graph_def,
[&](const NodeDef& node) {
return required_nodes.count(node.name()) > 0;
},
&filtered_graph_def);
output_graph_def->Clear();
for (const NodeDef& node : filtered_graph_def.node()) {
if (input_nodes.count(node.name())) {
NodeDef placeholder_node;
if (node.op() == "Placeholder") {
placeholder_node.CopyFrom(node);
} else {
placeholder_node.set_op("Placeholder");
placeholder_node.set_name(node.name());
DataType type;
TF_RETURN_IF_ERROR(TypeForPlaceholder(context, node.name(), &type));
TensorShape shape;
TF_RETURN_IF_ERROR(ShapeForPlaceholder(context, node.name(), &shape));
SetNodeAttr("dtype", type, &placeholder_node);
SetNodeAttr("shape", shape, &placeholder_node);
}
*(output_graph_def->mutable_node()->Add()) = placeholder_node;
} else {
*(output_graph_def->mutable_node()->Add()) = node;
}
}
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("strip_unused_nodes", StripUnusedNodes);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,286 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status StripUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class StripUnusedNodesTest : public ::testing::Test {
protected:
void TestSimpleAdd() {
GraphDef graph_def;
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("a_node");
add_node->add_input("b_node");
NodeDef* a_node = graph_def.add_node();
a_node->set_name("a_node");
a_node->set_op("Const");
NodeDef* b_node = graph_def.add_node();
b_node->set_name("b_node");
b_node->set_op("Const");
NodeDef* c_node = graph_def.add_node();
c_node->set_name("c_node");
c_node->set_op("Const");
GraphDef result;
TF_ASSERT_OK(StripUnusedNodes(graph_def, {{}, {"add_node"}}, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ(1, node_lookup.count("a_node"));
EXPECT_EQ(1, node_lookup.count("b_node"));
EXPECT_EQ(0, node_lookup.count("c_node"));
}
void TestCommonAncestor() {
GraphDef graph_def;
NodeDef* add_node1 = graph_def.add_node();
add_node1->set_name("add_node1");
add_node1->set_op("Add");
add_node1->add_input("add_node2");
add_node1->add_input("add_node3");
NodeDef* add_node2 = graph_def.add_node();
add_node2->set_name("add_node2");
add_node2->set_op("Add");
add_node2->add_input("const_node1");
add_node2->add_input("const_node2");
NodeDef* add_node3 = graph_def.add_node();
add_node3->set_name("add_node3");
add_node3->set_op("Add");
add_node3->add_input("const_node1");
add_node3->add_input("const_node3");
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* dangling_input = graph_def.add_node();
dangling_input->set_name("dangling_input");
dangling_input->set_op("Const");
NodeDef* add_node4 = graph_def.add_node();
add_node4->set_name("add_node4");
add_node4->set_op("Add");
add_node4->add_input("add_node2");
add_node4->add_input("add_node3");
GraphDef result;
TF_ASSERT_OK(StripUnusedNodes(
graph_def, {{"dangling_input"}, {"add_node1"}}, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node1"));
EXPECT_EQ(1, node_lookup.count("add_node2"));
EXPECT_EQ(1, node_lookup.count("add_node3"));
EXPECT_EQ(0, node_lookup.count("add_node4"));
EXPECT_EQ(1, node_lookup.count("const_node1"));
EXPECT_EQ(1, node_lookup.count("const_node2"));
EXPECT_EQ(1, node_lookup.count("const_node3"));
EXPECT_EQ(0, node_lookup.count("const_node4"));
EXPECT_EQ(1, node_lookup.count("dangling_input"));
}
void TestSimplePlaceholder() {
GraphDef graph_def;
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("mul_node");
add_node->add_input("a_node");
NodeDef* mul_node = graph_def.add_node();
mul_node->set_name("mul_node");
mul_node->set_op("Mul");
mul_node->add_input("b_node");
mul_node->add_input("c_node");
NodeDef* a_node = graph_def.add_node();
a_node->set_name("a_node");
a_node->set_op("Const");
NodeDef* b_node = graph_def.add_node();
b_node->set_name("b_node");
b_node->set_op("Const");
NodeDef* c_node = graph_def.add_node();
c_node->set_name("c_node");
c_node->set_op("Const");
GraphDef result;
TF_ASSERT_OK(
StripUnusedNodes(graph_def, {{"mul_node"}, {"add_node"}}, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ(1, node_lookup.count("mul_node"));
EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
EXPECT_EQ(DT_FLOAT, node_lookup["mul_node"]->attr().at("dtype").type());
EXPECT_EQ(TensorShape({}),
TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
EXPECT_EQ(1, node_lookup.count("a_node"));
EXPECT_EQ(0, node_lookup.count("b_node"));
EXPECT_EQ(0, node_lookup.count("c_node"));
}
void TestPlaceholderDefaultArgs() {
GraphDef graph_def;
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("mul_node");
add_node->add_input("a_node");
NodeDef* mul_node = graph_def.add_node();
mul_node->set_name("mul_node");
mul_node->set_op("Mul");
mul_node->add_input("b_node");
mul_node->add_input("c_node");
NodeDef* a_node = graph_def.add_node();
a_node->set_name("a_node");
a_node->set_op("Const");
NodeDef* b_node = graph_def.add_node();
b_node->set_name("b_node");
b_node->set_op("Const");
NodeDef* c_node = graph_def.add_node();
c_node->set_name("c_node");
c_node->set_op("Const");
GraphDef result;
TF_ASSERT_OK(StripUnusedNodes(graph_def,
{{"mul_node"},
{"add_node"},
{{"type", {"int32"}}, {"shape", {"1,2,3"}}}},
&result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ(1, node_lookup.count("mul_node"));
EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
EXPECT_EQ(DT_INT32, node_lookup["mul_node"]->attr().at("dtype").type());
EXPECT_EQ(TensorShape({1, 2, 3}),
TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
EXPECT_EQ(1, node_lookup.count("a_node"));
EXPECT_EQ(0, node_lookup.count("b_node"));
EXPECT_EQ(0, node_lookup.count("c_node"));
}
void TestPlaceholderNamedArgs() {
GraphDef graph_def;
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("mul_node");
add_node->add_input("a_node");
NodeDef* mul_node = graph_def.add_node();
mul_node->set_name("mul_node");
mul_node->set_op("Mul");
mul_node->add_input("b_node");
mul_node->add_input("c_node");
NodeDef* a_node = graph_def.add_node();
a_node->set_name("a_node");
a_node->set_op("Const");
NodeDef* b_node = graph_def.add_node();
b_node->set_name("b_node");
b_node->set_op("Const");
NodeDef* c_node = graph_def.add_node();
c_node->set_name("c_node");
c_node->set_op("Const");
GraphDef result;
TF_ASSERT_OK(StripUnusedNodes(graph_def,
{{"mul_node", "a_node"},
{"add_node"},
{{"name", {"a_node", "mul_node"}},
{"type_for_name", {"int64", "quint8"}},
{"shape_for_name", {"1,2", "1, 2, 3"}}}},
&result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
EXPECT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ(1, node_lookup.count("mul_node"));
EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
EXPECT_EQ(DT_QUINT8, node_lookup["mul_node"]->attr().at("dtype").type());
EXPECT_EQ(TensorShape({1, 2, 3}),
TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
EXPECT_EQ(1, node_lookup.count("a_node"));
EXPECT_EQ("Placeholder", node_lookup["a_node"]->op());
EXPECT_EQ(DT_INT64, node_lookup["a_node"]->attr().at("dtype").type());
EXPECT_EQ(TensorShape({1, 2}),
TensorShape(node_lookup["a_node"]->attr().at("shape").shape()));
EXPECT_EQ(0, node_lookup.count("b_node"));
EXPECT_EQ(0, node_lookup.count("c_node"));
}
};
TEST_F(StripUnusedNodesTest, TestSimpleAdd) { TestSimpleAdd(); }
TEST_F(StripUnusedNodesTest, TestCommonAncestor) { TestCommonAncestor(); }
TEST_F(StripUnusedNodesTest, TestSimplePlaceholder) { TestSimplePlaceholder(); }
TEST_F(StripUnusedNodesTest, TestPlaceholderDefaultArgs) {
TestPlaceholderDefaultArgs();
}
TEST_F(StripUnusedNodesTest, TestPlaceholderNamedArgs) {
TestPlaceholderNamedArgs();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,205 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This program prints out a summary of a GraphDef file's contents, listing
// things that are useful for debugging and reusing the model it contains. For
// example it looks at the graph structure and op types to figure out likely
// input and output nodes, and shows which ops are used by the graph. To use it,
// run something like this:
//
// bazel build tensorflow/tools/graph_transforms:summarize_graph
// bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
// --in_graph=my_graph.pb
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
namespace {
Status SummarizeGraph(const GraphDef& graph) {
std::vector<const NodeDef*> placeholders;
for (const NodeDef& node : graph.node()) {
if (node.op() == "Placeholder") {
placeholders.push_back(&node);
}
}
if (placeholders.empty()) {
std::cout << "No inputs spotted." << std::endl;
} else {
std::cout << "Found " << placeholders.size() << " possible inputs: ";
for (const NodeDef* node : placeholders) {
TensorShape shape;
if (node->attr().count("shape")) {
TensorShapeProto shape_proto = node->attr().at("shape").shape();
shape = TensorShape(shape_proto);
}
DataType dtype = node->attr().at("dtype").type();
std::cout << "(name=" << node->name();
std::cout << ", type=" << DataTypeString(dtype) << "(" << dtype << ")";
std::cout << ", shape=" << shape.DebugString() << ") ";
}
std::cout << std::endl;
}
std::map<string, std::vector<const NodeDef*>> output_map;
MapNodesToOutputs(graph, &output_map);
std::vector<const NodeDef*> outputs;
for (const NodeDef& node : graph.node()) {
if (output_map.count(node.name()) == 0) {
outputs.push_back(&node);
}
}
if (outputs.empty()) {
std::cout << "No outputs spotted." << std::endl;
} else {
std::cout << "Found " << outputs.size() << " possible outputs: ";
for (const NodeDef* node : outputs) {
std::cout << "(name=" << node->name();
std::cout << ", op=" << node->op() << ") ";
}
std::cout << std::endl;
}
int const_count = 0;
int variable_count = 0;
int identity_count = 0;
int control_edge_count = 0;
std::map<string, int> device_counts;
for (const NodeDef& node : graph.node()) {
if (node.op() == "Const") {
++const_count;
} else if (node.op() == "Variable") {
++variable_count;
} else if (node.op() == "Identity") {
++identity_count;
}
for (const string& input : node.input()) {
if (input.substr(0, 1) == "^") {
++control_edge_count;
}
}
if (node.device() != "") {
++device_counts[node.device()];
}
}
std::cout << "Found " << const_count << " consts, " << variable_count
<< " variables, " << identity_count << " identities, and "
<< control_edge_count << " control_edges" << std::endl;
if (!device_counts.empty()) {
for (const auto& device_info : device_counts) {
std::cout << device_info.second << " nodes assigned to device '"
<< device_info.first << "'";
}
}
std::vector<std::pair<string, string>> invalid_inputs;
FindInvalidInputs(graph, &invalid_inputs);
if (!invalid_inputs.empty()) {
for (const std::pair<string, string>& invalid_input : invalid_inputs) {
std::cout << "Invalid input " << invalid_input.second << " for node "
<< invalid_input.first << std::endl;
}
return errors::Internal(
"Invalid graph with inputs referring to nonexistent nodes");
}
std::map<string, int> op_counts;
for (const NodeDef& node : graph.node()) {
++op_counts[node.op()];
}
std::vector<std::pair<string, int>> op_counts_vec(op_counts.begin(),
op_counts.end());
std::sort(op_counts_vec.begin(), op_counts_vec.end(),
[](std::pair<string, int> a, std::pair<string, int> b) {
return (a.second > b.second);
});
std::cout << "Op types used: ";
bool is_first = true;
for (const std::pair<string, int>& op_count : op_counts_vec) {
if (!is_first) {
std::cout << ", ";
} else {
is_first = false;
}
std::cout << op_count.second << " " << op_count.first;
}
std::cout << std::endl;
return Status::OK();
}
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
string in_graph = "";
string out_graph = "";
string inputs_string = "";
string outputs_string = "";
string transforms_string = "";
std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"),
};
string usage = Flags::Usage(argv[0], flag_list);
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
// We need to call this to set up global state for TensorFlow.
port::InitMain(argv[0], &argc, &argv);
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
return -1;
}
if (in_graph.empty()) {
LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
return -1;
}
GraphDef graph_def;
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
if (!load_status.ok()) {
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
<< load_status.error_message();
LOG(ERROR) << usage;
return -1;
}
Status summarize_result = SummarizeGraph(graph_def);
if (!summarize_result.ok()) {
LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
return -1;
}
return 0;
}
} // namespace
} // namespace graph_transforms
} // namespace tensorflow
int main(int argc, char* argv[]) {
return tensorflow::graph_transforms::ParseFlagsAndSummarizeGraph(argc, argv);
}

View File

@ -0,0 +1,280 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/transform_graph.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
using tensorflow::strings::Scanner;
Status ParseTransformParameters(const string& transforms_string,
TransformParameters* params_list) {
params_list->clear();
enum {
TRANSFORM_NAME,
TRANSFORM_PARAM_NAME,
TRANSFORM_PARAM_VALUE,
} state = TRANSFORM_NAME;
StringPiece remaining(transforms_string);
StringPiece match;
StringPiece transform_name;
StringPiece parameter_name;
StringPiece parameter_value;
TransformFuncParameters func_parameters;
while (!remaining.empty()) {
if (state == TRANSFORM_NAME) {
// Reset the list of parameters.
func_parameters.clear();
// Eat up any leading spaces.
Scanner(remaining).Any(Scanner::SPACE).GetResult(&remaining, &match);
// See if we have a valid transform name.
const bool found_transform_name =
Scanner(remaining)
.Any(Scanner::LETTER_DIGIT_UNDERSCORE)
.GetResult(&remaining, &transform_name);
if (!found_transform_name) {
return errors::InvalidArgument("Looking for transform name, but found ",
remaining.ToString().c_str());
}
if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
state = TRANSFORM_PARAM_NAME;
} else {
// Add a transform with no parameters.
params_list->push_back({transform_name.ToString(), func_parameters});
transform_name = "";
state = TRANSFORM_NAME;
}
} else if (state == TRANSFORM_PARAM_NAME) {
if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
params_list->push_back({transform_name.ToString(), func_parameters});
transform_name = "";
state = TRANSFORM_NAME;
} else {
// Eat up any leading spaces or commas.
Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match);
Scanner(remaining).Any(Scanner::SPACE).GetResult(&remaining, &match);
// See if we have a valid parameter name.
const bool found_parameter_name =
Scanner(remaining)
.Any(Scanner::LETTER_DIGIT_UNDERSCORE)
.GetResult(&remaining, &parameter_name);
if (!found_parameter_name) {
return errors::InvalidArgument(
"Looking for parameter name, but found ",
remaining.ToString().c_str());
}
if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
state = TRANSFORM_PARAM_VALUE;
} else {
return errors::InvalidArgument("Looking for =, but found ",
remaining.ToString().c_str());
}
}
} else if (state == TRANSFORM_PARAM_VALUE) {
bool found_parameter_value;
// Deal with quoted values.
if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) {
found_parameter_value =
Scanner(remaining).ScanEscapedUntil('"').GetResult(
&remaining, &parameter_value);
if (found_parameter_value) {
Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match);
}
} else {
// See if we have a valid parameter name.
found_parameter_value =
Scanner(remaining)
.Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
.GetResult(&remaining, &parameter_value);
}
if (!found_parameter_value) {
return errors::InvalidArgument("Looking for parameter name, but found ",
remaining.ToString().c_str());
}
func_parameters[parameter_name.ToString()].push_back(
parameter_value.ToString());
// Eat up any trailing quotes.
Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
state = TRANSFORM_PARAM_NAME;
}
}
return Status::OK();
}
int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
string in_graph = "";
string out_graph = "";
string inputs_string = "";
string outputs_string = "";
string transforms_string = "";
std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"),
Flag("out_graph", &out_graph, "output graph file name"),
Flag("inputs", &inputs_string, "inputs"),
Flag("outputs", &outputs_string, "outputs"),
Flag("transforms", &transforms_string, "list of transforms"),
};
string usage = Flags::Usage(argv[0], flag_list);
usage += "\nTransforms are:\n";
TransformRegistry* transform_registry = GetTransformRegistry();
for (const auto& pair : *transform_registry) {
usage += pair.first + "\n";
}
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
// We need to call this to set up global state for TensorFlow.
if (init_main) {
port::InitMain(argv[0], &argc, &argv);
}
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
return -1;
}
if (in_graph.empty()) {
LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
return -1;
}
if (out_graph.empty()) {
LOG(ERROR) << "out_graph graph can't be empty.\n" << usage;
return -1;
}
if (transforms_string.empty()) {
LOG(ERROR) << "You must specify at least one transform.\n" << usage;
return -1;
}
std::vector<string> inputs = str_util::Split(inputs_string, ',');
std::vector<string> outputs = str_util::Split(outputs_string, ',');
TransformParameters transform_params;
Status parse_status =
ParseTransformParameters(transforms_string, &transform_params);
if (!parse_status.ok()) {
LOG(ERROR) << "Failed to parse --transform argument, error was "
<< parse_status.error_message();
return -1;
}
if (transform_params.empty()) {
LOG(ERROR) << "You must specify at least one transform.\n" << usage;
return -1;
}
GraphDef graph_def;
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
if (!load_status.ok()) {
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
<< load_status.error_message();
LOG(ERROR) << usage;
return -1;
}
Status transform_result =
TransformGraph(inputs, outputs, transform_params, &graph_def);
if (!transform_result.ok()) {
LOG(ERROR) << transform_result.error_message();
LOG(ERROR) << usage;
return -1;
}
Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
if (!save_status.ok()) {
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
<< save_status.error_message();
return -1;
}
return 0;
}
Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
bool* ignore_errors) {
*ignore_errors = false;
if (transform_params.count("ignore_errors") &&
(!transform_params.at("ignore_errors").empty())) {
const string& ignore_errors_string =
str_util::Lowercase(transform_params.at("ignore_errors").at(0));
if (ignore_errors_string == "true") {
*ignore_errors = true;
} else if (ignore_errors_string == "false") {
*ignore_errors = false;
} else {
return errors::InvalidArgument(
"ignore_errors should be true or false, found ",
ignore_errors_string);
}
}
return Status::OK();
}
Status TransformGraph(const std::vector<string>& inputs,
const std::vector<string>& outputs,
const TransformParameters& transform_params,
GraphDef* graph_def) {
TransformRegistry* transform_registry = GetTransformRegistry();
for (const auto& transform_info : transform_params) {
const string& transform_name = transform_info.first;
if (transform_name == "") {
continue;
}
if (!transform_registry->count(transform_name)) {
return errors::InvalidArgument("Transform '", transform_name,
"' not recognized.");
}
LOG(INFO) << "Applying " << transform_name;
const TransformFunc& transform_func =
transform_registry->at(transform_name);
TransformFuncContext context;
context.input_names = inputs;
context.output_names = outputs;
context.params = transform_info.second;
bool ignore_errors;
TF_RETURN_IF_ERROR(
ShouldIgnoreErrors(transform_info.second, &ignore_errors));
GraphDef transformed_graph_def;
Status transform_result =
transform_func(*graph_def, context, &transformed_graph_def);
if (!transform_result.ok()) {
if (ignore_errors) {
LOG(ERROR) << transform_name << ": Ignoring error "
<< transform_result.error_message();
transformed_graph_def = *graph_def;
} else {
return transform_result;
}
}
// Copy over the library from the original input graph.
transformed_graph_def.mutable_library()->CopyFrom(graph_def->library());
TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def));
*graph_def = transformed_graph_def;
}
return Status::OK();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,50 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Convenience function to handle argument parsing for the command line tool.
// If init_main is false, we're testing so don't call core initialization.
int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main);
// Handles converting the transforms string into transform names and their
// arguments.
typedef std::vector<std::pair<string, TransformFuncParameters>>
TransformParameters;
Status ParseTransformParameters(const string& transforms_string,
TransformParameters* params_list);
// Applies a series of transformations to the GraphDef. These transforms are
// defined by modules that call REGISTER_GRAPH_TRANSFORM() to associate a
// function with a name string.
Status TransformGraph(const std::vector<string>& inputs,
const std::vector<string>& outputs,
const TransformParameters& transform_params,
GraphDef* graph_def);
} // namespace graph_transforms
} // namespace tensorflow
#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_

View File

@ -0,0 +1,53 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Tool that applies a series of transformations to a frozen GraphDef file.
// It takes a flexible list of transforms either on the command line, and runs
// those on the incoming graph to produce the result. This allows you to build a
// processing pipeline when preparing models for deployment.
//
// bazel build tensorflow/tools/graph_transforms/fold_constants_tool &&
// bazel-bin/tensorflow/tools/graph_transforms/fold_constants_tool \
// --in_graph=graph_def.pb \
// --out_graph=transformed_graph_def.pb \
// --inputs=input1,input2 \
// --outputs=output1,output2 \
// --transforms="fold_constants order_nodes"
//
// Parameters:
// in_graph - name of a file with a frozen GraphDef proto in binary format.
// out_graph - name of the output file to save the transformed version to.
// inputs - layer names of the nodes that will be fed data.
// outputs - layer names of the nodes that will be read from after running.
// transforms - space-separated names of the transforms to apply.
//
// List of implemented transforms:
// fold_constants - Merges constant expression subgraphs into single constants,
// which can help reduce the number of ops and make subsequent transforms
// optimizations more effective.
// order_nodes - Sorts the GraphDef nodes in execution order, which can help
// simple inference engines that want to avoid complexity in their executors.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_graph.h"
int main(int argc, char* argv[]) {
return tensorflow::graph_transforms::ParseFlagsAndTransformGraph(argc, argv,
true);
}

View File

@ -0,0 +1,228 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/transform_graph.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declared here so we don't have to expose it in the public header.
Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
bool* ignore_errors);
namespace {
Status test_empty_graph_transform(const GraphDef& graph_def,
const TransformFuncContext& context,
GraphDef* result) {
result->Clear();
return Status::OK();
}
} // namespace
REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform",
test_empty_graph_transform);
class TransformGraphTest : public ::testing::Test {
protected:
void TestConstantFolding() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 100;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const =
Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const =
Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
Output placeholder =
Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
Output mul =
Mul(root.WithOpName("output_expect_remains"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
string graph_def_serialized;
graph_def.SerializeToString(&graph_def_serialized);
const string dir = testing::TmpDir();
const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb");
const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb");
TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb,
graph_def_serialized));
std::vector<string> args = {"some_binary",
"--in_graph=" + in_filename_pb,
"--out_graph=" + out_filename_pb,
"--inputs=placeholder_expect_remains",
"--outputs=output_expect_remains",
"--transforms=fold_constants"};
const int argc = 6;
EXPECT_EQ(argc, args.size());
char* argv[argc];
std::vector<char*> char_strings;
for (int i = 0; i < argc; ++i) {
string arg = args[i];
char* char_string = new char[arg.size() + 1];
std::copy_n(arg.c_str(), arg.size() + 1, char_string);
argv[i] = char_string;
char_strings.push_back(char_string);
}
ParseFlagsAndTransformGraph(argc, argv, false);
for (char* char_string : char_strings) {
delete[] char_string;
}
GraphDef out_graph_def;
TF_EXPECT_OK(
ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def));
std::map<string, const NodeDef*> out_node_map;
graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map);
for (const NodeDef& node : out_graph_def.node()) {
const StringPiece name(node.name());
const int occurrence_count = out_node_map.count(node.name());
if (name.ends_with("expect_removed")) {
EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
}
if (name.ends_with("expect_remains")) {
EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
}
}
}
void TestTransformRegistration() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Output placeholder =
Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_EQ(1, graph_def.node().size());
TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}},
&graph_def));
EXPECT_EQ(0, graph_def.node().size());
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
Status no_such_status =
TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def);
EXPECT_TRUE(
StringPiece(no_such_status.ToString()).contains("not recognized"));
}
void TestParseTransformParameters() {
TransformParameters params_list;
ParseTransformParameters("foo", &params_list);
EXPECT_EQ(1, params_list.size());
EXPECT_EQ("foo", params_list[0].first);
EXPECT_TRUE(params_list[0].second.empty());
ParseTransformParameters("foo bar", &params_list);
EXPECT_EQ(2, params_list.size());
EXPECT_EQ("foo", params_list[0].first);
EXPECT_TRUE(params_list[0].second.empty());
EXPECT_EQ("bar", params_list[1].first);
EXPECT_TRUE(params_list[1].second.empty());
ParseTransformParameters("foo() bar()", &params_list);
EXPECT_EQ(2, params_list.size());
EXPECT_EQ("foo", params_list[0].first);
EXPECT_TRUE(params_list[0].second.empty());
EXPECT_EQ("bar", params_list[1].first);
EXPECT_TRUE(params_list[1].second.empty());
ParseTransformParameters("foo(bob_something=sue)", &params_list);
EXPECT_EQ(1, params_list.size());
EXPECT_EQ("foo", params_list[0].first);
EXPECT_EQ(1, params_list[0].second.count("bob_something"));
EXPECT_EQ(1, params_list[0].second["bob_something"].size());
EXPECT_EQ("sue", params_list[0].second["bob_something"][0]);
ParseTransformParameters("bar(a=1, b=2, a=3)", &params_list);
EXPECT_EQ(1, params_list.size());
EXPECT_EQ("bar", params_list[0].first);
EXPECT_EQ(1, params_list[0].second.count("a"));
EXPECT_EQ(2, params_list[0].second["a"].size());
EXPECT_EQ("1", params_list[0].second["a"][0]);
EXPECT_EQ("3", params_list[0].second["a"][1]);
EXPECT_EQ(1, params_list[0].second.count("b"));
EXPECT_EQ(1, params_list[0].second["b"].size());
EXPECT_EQ("2", params_list[0].second["b"][0]);
ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)", &params_list);
EXPECT_EQ(1, params_list.size());
EXPECT_EQ("bar", params_list[0].first);
EXPECT_EQ(1, params_list[0].second.count("a"));
EXPECT_EQ(2, params_list[0].second["a"].size());
EXPECT_EQ("1", params_list[0].second["a"][0]);
EXPECT_EQ("3", params_list[0].second["a"][1]);
EXPECT_EQ(1, params_list[0].second.count("b"));
EXPECT_EQ(1, params_list[0].second["b"].size());
EXPECT_EQ("1,2,3", params_list[0].second["b"][0]);
}
void TestShouldIgnoreErrors() {
bool ignore_errors;
TF_EXPECT_OK(
ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors));
EXPECT_TRUE(ignore_errors);
TF_EXPECT_OK(
ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors));
EXPECT_FALSE(ignore_errors);
TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors));
EXPECT_FALSE(ignore_errors);
EXPECT_FALSE(
ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok());
}
};
TEST_F(TransformGraphTest, TestConstantFolding) { TestConstantFolding(); }
TEST_F(TransformGraphTest, TestTransformRegistration) {
TestTransformRegistration();
}
TEST_F(TransformGraphTest, TestParseTransformParameters) {
TestParseTransformParameters();
}
TEST_F(TransformGraphTest, TestShouldIgnoreErrors) { TestShouldIgnoreErrors(); }
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -15,12 +15,50 @@ limitations under the License.
#include "tensorflow/tools/graph_transforms/transform_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
namespace tensorflow { namespace tensorflow {
namespace graph_transforms { namespace graph_transforms {
namespace {
inline bool IsMerge(const NodeDef& node_def) {
return node_def.op() == "Merge" || node_def.op() == "RefMerge";
}
void RecordMatchedNodes(const NodeMatch& match,
std::set<string>* matched_nodes) {
matched_nodes->insert(match.node.name());
for (const NodeMatch& input_match : match.inputs) {
RecordMatchedNodes(input_match, matched_nodes);
}
}
inline uint64 Hash64String(const string& input) {
return Hash64(input.data(), input.size());
}
} // namespace
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
std::set<string> found_nodes;
std::vector<NodeMatch> current_matches = {match};
while (!current_matches.empty()) {
std::vector<NodeMatch> next_matches;
for (const NodeMatch& current_match : current_matches) {
if (found_nodes.count(current_match.node.name())) {
continue;
}
found_nodes.insert(current_match.node.name());
result->push_back(current_match.node);
for (const NodeMatch& input_match : current_match.inputs) {
next_matches.push_back(input_match);
}
}
current_matches = next_matches;
}
}
void MapNamesToNodes(const GraphDef& graph_def, void MapNamesToNodes(const GraphDef& graph_def,
std::map<string, const NodeDef*>* result) { std::map<string, const NodeDef*>* result) {
for (const NodeDef& node : graph_def.node()) { for (const NodeDef& node : graph_def.node()) {
@ -28,7 +66,19 @@ void MapNamesToNodes(const GraphDef& graph_def,
} }
} }
void NodeNamePartsFromInput(string input_name, string* prefix, void MapNodesToOutputs(const GraphDef& graph_def,
std::map<string, std::vector<const NodeDef*>>* result) {
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(graph_def, &node_map);
for (const NodeDef& node : graph_def.node()) {
for (const string& input : node.input()) {
string input_node_name = NodeNameFromInput(input);
(*result)[input_node_name].push_back(&node);
}
}
}
void NodeNamePartsFromInput(const string& input_name, string* prefix,
string* node_name, string* suffix) { string* node_name, string* suffix) {
std::vector<string> input_parts = str_util::Split(input_name, ':'); std::vector<string> input_parts = str_util::Split(input_name, ':');
if (input_parts.size() < 2) { if (input_parts.size() < 2) {
@ -45,7 +95,7 @@ void NodeNamePartsFromInput(string input_name, string* prefix,
*node_name = node_name_piece.ToString(); *node_name = node_name_piece.ToString();
} }
string NodeNameFromInput(string input_name) { string NodeNameFromInput(const string& input_name) {
string prefix; string prefix;
string node_name; string node_name;
string suffix; string suffix;
@ -53,6 +103,57 @@ string NodeNameFromInput(string input_name) {
return node_name; return node_name;
} }
string CanonicalInputName(const string& input_name) {
string prefix;
string node_name;
string suffix;
NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
if (suffix == "") {
suffix = ":0";
}
return prefix + node_name + suffix;
}
uint64 HashNodeDef(const NodeDef& node) {
uint64 hash = Hash64String(node.op());
hash = Hash64Combine(hash, Hash64String(node.name()));
for (const string& input : node.input()) {
hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
}
hash = Hash64Combine(hash, Hash64String(node.device()));
std::vector<string> attr_names;
attr_names.reserve(node.attr().size());
for (const auto& attr : node.attr()) {
attr_names.push_back(attr.first);
}
std::sort(attr_names.begin(), attr_names.end());
string attr_serialized;
for (const string& attr_name : attr_names) {
auto attr = node.attr().at(attr_name);
attr.SerializeToString(&attr_serialized);
hash = Hash64Combine(hash, Hash64String(attr_serialized));
}
return hash;
}
void AddNodeInput(const string& input_name, NodeDef* node) {
*(node->mutable_input()->Add()) = input_name;
}
void CopyNodeAttr(const NodeDef& source, const string& source_key,
const string& dest_key, NodeDef* dest) {
CHECK_NE(0, source.attr().count(source_key))
<< "No key '" << source_key << "' found in " << source.DebugString();
(*(dest->mutable_attr()))[dest_key].CopyFrom(source.attr().at(source_key));
}
Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
TensorProto tensor_proto = node.attr().at(key).tensor();
Tensor tensor;
CHECK(tensor.FromProto(tensor_proto));
return tensor;
}
void FilterGraphDef(const GraphDef& input_graph_def, void FilterGraphDef(const GraphDef& input_graph_def,
std::function<bool(const NodeDef&)> selector, std::function<bool(const NodeDef&)> selector,
GraphDef* output_graph_def) { GraphDef* output_graph_def) {
@ -77,5 +178,425 @@ void RemoveAttributes(const GraphDef& input_graph_def,
} }
} }
Status SortByExecutionOrder(const GraphDef& input_graph_def,
GraphDef* output_graph_def) {
const int num_nodes = input_graph_def.node_size();
std::vector<int> ready;
std::vector<int> pending_count;
pending_count.reserve(num_nodes);
std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);
std::map<string, int> name_index;
for (int i = 0; i < input_graph_def.node_size(); ++i) {
const NodeDef& node(input_graph_def.node(i));
name_index[node.name()] = i;
}
// Parse the inputs for each node.
for (int n = 0; n < num_nodes; ++n) {
const NodeDef& node_def(input_graph_def.node(n));
if (IsMerge(node_def)) {
// for merge only wait for one non-control input.
int32 num_control_edges = 0;
for (int i = 0; i < node_def.input_size(); ++i) {
StringPiece input_name(node_def.input(i));
if (input_name.starts_with("^")) {
num_control_edges++;
}
}
pending_count.push_back(num_control_edges + 1);
} else {
pending_count.push_back(node_def.input_size());
}
if (node_def.input_size() == 0) {
ready.push_back(n);
continue;
}
for (int i = 0; i < node_def.input_size(); ++i) {
const string& input_name = node_def.input(i);
const string& input_node_name = NodeNameFromInput(input_name);
if (!name_index.count(input_node_name)) {
return errors::InvalidArgument("Node '", node_def.name(),
"': Unknown input node '",
node_def.input(i), "'");
}
outputs[name_index[input_node_name]].push_back(n);
}
}
int processed = 0;
output_graph_def->Clear();
// Process the NodeDefs in topological order.
// Code above sets this up by filling in ready_ with nodes that have no
// inputs, pending_counts_ with the number of inputs for each node and
// outputs_ with the outputs of each node.
while (!ready.empty()) {
int o = ready.back();
ready.pop_back();
++processed;
const NodeDef& node_def(input_graph_def.node(o));
output_graph_def->mutable_node()->Add()->CopyFrom(node_def);
// Update pending_count for outputs.
for (size_t i = 0; i < outputs[o].size(); ++i) {
const int output = outputs[o][i];
pending_count[output]--;
if (pending_count[output] == 0) {
ready.push_back(output);
}
}
}
if (processed < input_graph_def.node_size()) {
return errors::InvalidArgument(input_graph_def.node_size() - processed,
" nodes in a cycle");
}
return Status::OK();
}
string OpTypePattern::DebugString() const {
string result = "{" + op + ", {";
for (const OpTypePattern& input : inputs) {
result += input.DebugString() + ",";
}
result += "}}";
return result;
}
string NodeMatch::DebugString() const {
string result = "{";
result += node.DebugString();
result += ", {";
for (const NodeMatch& input : inputs) {
result += input.DebugString() + ",";
}
result += "}}";
return result;
}
GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
SortByExecutionOrder(graph_def, &graph_def_);
MapNamesToNodes(graph_def_, &node_map_);
}
Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
std::vector<NodeMatch>* matches) {
std::set<string> matched_nodes;
for (const NodeDef& node : graph_def_.node()) {
// Skip any nodes that are already part of a match.
if (matched_nodes.count(node.name())) {
continue;
}
NodeMatch match;
if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) {
RecordMatchedNodes(match, &matched_nodes);
matches->push_back(match);
}
}
return Status::OK();
}
bool GraphMatcher::DoesOpTypeMatch(
const NodeDef& node, const OpTypePattern& pattern,
const std::set<string>& previously_matched_nodes, NodeMatch* match) {
VLOG(1) << "Looking at node " << node.DebugString();
VLOG(1) << "pattern=" << pattern.DebugString();
VLOG(1) << "match=" << match->DebugString();
if (previously_matched_nodes.count(node.name())) {
VLOG(1) << "node " << node.name() << " has been previously matched";
return false;
}
bool pattern_matched = false;
if (pattern.op == "*") {
pattern_matched = true;
} else {
std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
for (const string& pattern_op : pattern_ops) {
if (node.op() == pattern_op) {
pattern_matched = true;
}
}
}
if (!pattern_matched) {
VLOG(1) << "node.op() != pattern.op()";
return false;
}
match->node = node;
// Ignore any control inputs for pattern-matching purposes
std::vector<string> non_control_inputs;
for (const string& input : node.input()) {
if (!input.empty() && (input[0] != '^')) {
non_control_inputs.push_back(input);
}
}
if (pattern.inputs.empty()) {
// If there are no inputs, assume that's the end of the pattern.
return true;
}
if (non_control_inputs.size() != pattern.inputs.size()) {
VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
return false;
}
for (int i = 0; i < pattern.inputs.size(); ++i) {
const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
const NodeDef& input_node = *(node_map_[input_node_name]);
const OpTypePattern& input_pattern = pattern.inputs[i];
match->inputs.push_back(NodeMatch());
NodeMatch* input_match = &(match->inputs.back());
if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
input_match)) {
return false;
}
}
return true;
}
Status ReplaceMatchingOpTypes(
const GraphDef& input_graph_def, const OpTypePattern& pattern,
const std::function<Status(const NodeMatch&, const std::set<string>&,
const std::set<string>&, std::vector<NodeDef>*)>&
node_generator,
const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
// Start off by retrieving all the matching subgraphs.
GraphMatcher matcher(input_graph_def);
std::vector<NodeMatch> matches;
matcher.GetOpTypeMatches(pattern, &matches);
// Do some housekeeping so we can easily look up the resulting matches given
// a node name.
std::set<string> matched_nodes;
std::map<string, const NodeMatch*> matches_by_head_name;
for (const NodeMatch& match : matches) {
matches_by_head_name[match.node.name()] = &match;
RecordMatchedNodes(match, &matched_nodes);
}
std::map<string, std::vector<const NodeDef*>> outputs_map;
MapNodesToOutputs(input_graph_def, &outputs_map);
// Go through all the nodes in the input graph, see if they are part of a
// match or if they can be left untouched.
output_graph_def->Clear();
for (const NodeDef& input_node : input_graph_def.node()) {
if (matches_by_head_name.count(input_node.name())) {
// This node is the beginning of a match, so call the replacement function
// after setting up some information it will need.
const NodeMatch* match = matches_by_head_name[input_node.name()];
std::vector<NodeDef> matched_nodes_array;
MatchedNodesAsArray(*match, &matched_nodes_array);
// This tells us whether a node is part of the current match.
std::set<string> matched_nodes_lookup;
for (const NodeDef& matched_node : matched_nodes_array) {
matched_nodes_lookup.insert(matched_node.name());
}
// These are helper arrays that the replacement function can use to tell
// whether it can safely remove an internal node (because nothing outside
// of the match uses it) or whether external nodes depend on it.
std::set<string> input_nodes;
std::set<string> output_nodes;
for (const NodeDef& matched_node : matched_nodes_array) {
// Look through all of this node's inputs, and if any of them come from
// outside the match, then this should be noted as one of the external
// inputs of the subgraph.
for (const string& input_name : matched_node.input()) {
string input_node_name = NodeNameFromInput(input_name);
if (!matched_nodes_lookup.count(input_node_name)) {
input_nodes.insert(matched_node.name());
}
}
// Do a reverse input lookup, to see which other nodes use the current
// one as an input. If any of those nodes are outside the match
// subgraph, then the current node is marked as an output node that
// shouldn't be removed.
if (outputs_map.count(matched_node.name())) {
for (const NodeDef* dependent_node :
outputs_map[matched_node.name()]) {
if (!matched_nodes_lookup.count(dependent_node->name())) {
output_nodes.insert(matched_node.name());
}
}
}
}
// Call the generator function and add all the returned nodes to the
// graph.
std::vector<NodeDef> new_nodes;
TF_RETURN_IF_ERROR(
node_generator(*match, input_nodes, output_nodes, &new_nodes));
std::set<string> new_node_names;
for (const NodeDef& new_node : new_nodes) {
new_node_names.insert(new_node.name());
}
// Check to make sure the generator function preserved all of the nodes
// that are used elsewhere in the graph, and add them back in if not.
bool abort_replacement = false;
if (!options.allow_inconsistencies) {
for (const string& expected_output : output_nodes) {
if (!new_node_names.count(expected_output)) {
LOG(WARNING) << "Expected " << expected_output
<< " to be preserved.";
abort_replacement = true;
}
}
}
if (abort_replacement) {
LOG(WARNING) << "Generator function didn't preserve needed nodes, "
<< "copying old replacements back in instead.";
std::vector<NodeDef> old_nodes;
MatchedNodesAsArray(*match, &old_nodes);
for (const NodeDef& old_node : old_nodes) {
NodeDef* added_node = output_graph_def->mutable_node()->Add();
added_node->CopyFrom(old_node);
}
} else {
for (const NodeDef& new_node : new_nodes) {
NodeDef* added_node = output_graph_def->mutable_node()->Add();
added_node->CopyFrom(new_node);
}
}
} else if (!matched_nodes.count(input_node.name())) {
// This node isn't part of any match, so just copy it over.
NodeDef* added_node = output_graph_def->mutable_node()->Add();
added_node->CopyFrom(input_node);
} else {
// Do nothing, because this is an internal part of a matching subgraph,
// and so will have been replaced by a new replacement subgraph.
}
}
return Status::OK();
}
Status RenameNodeInputs(const GraphDef& input_graph_def,
const std::map<string, string>& inputs_to_rename,
GraphDef* output_graph_def) {
std::map<string, std::vector<std::pair<string, string>>>
canonical_inputs_to_rename;
for (const auto& input_to_rename : inputs_to_rename) {
canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)]
.push_back({input_to_rename.first, input_to_rename.second});
}
output_graph_def->Clear();
for (const NodeDef& node : input_graph_def.node()) {
NodeDef* new_node = output_graph_def->mutable_node()->Add();
new_node->CopyFrom(node);
new_node->mutable_input()->Clear();
for (const string& input_name : node.input()) {
std::set<string> already_visited;
string new_input_name = input_name;
while (
canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) {
string input_node_name = NodeNameFromInput(new_input_name);
if (already_visited.count(input_node_name)) {
return errors::InvalidArgument(
"RenameNodeInputs argument contains a cycle for ",
input_node_name);
}
already_visited.insert(input_node_name);
bool any_match_found = false;
for (const std::pair<string, string>& input_to_rename :
canonical_inputs_to_rename.at(input_node_name)) {
const string& source_name = input_to_rename.first;
const string& dest_name = input_to_rename.second;
bool is_match;
string match_name;
if (StringPiece(source_name).ends_with(":*")) {
is_match = true;
string prefix;
string unused_node_name;
string suffix;
NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name,
&suffix);
match_name = prefix + dest_name + suffix;
} else {
is_match = (CanonicalInputName(source_name) ==
CanonicalInputName(new_input_name));
match_name = dest_name;
}
if (is_match) {
new_input_name = match_name;
any_match_found = true;
}
}
if (!any_match_found) {
break;
}
}
*(new_node->mutable_input()->Add()) = new_input_name;
}
}
return Status::OK();
}
void CopyOriginalMatch(const NodeMatch& match,
std::vector<NodeDef>* new_nodes) {
std::vector<NodeDef> old_nodes;
MatchedNodesAsArray(match, &old_nodes);
for (const NodeDef& old_node : old_nodes) {
new_nodes->push_back(old_node);
}
}
TransformRegistry* GetTransformRegistry() {
static TransformRegistry transform_registry;
return &transform_registry;
}
void FindInvalidInputs(const GraphDef& graph_def,
std::vector<std::pair<string, string>>* invalid_inputs) {
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(graph_def, &node_map);
for (const NodeDef& node : graph_def.node()) {
for (const string& input : node.input()) {
string input_node = NodeNameFromInput(input);
if (!node_map.count(input_node)) {
invalid_inputs->push_back({node.name(), input_node});
}
}
}
}
Status IsGraphValid(const GraphDef& graph_def) {
std::vector<std::pair<string, string>> invalid_inputs;
FindInvalidInputs(graph_def, &invalid_inputs);
if (!invalid_inputs.empty()) {
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(graph_def, &node_map);
for (const std::pair<string, string>& invalid_input : invalid_inputs) {
LOG(ERROR) << "Invalid input " << invalid_input.second << " for node "
<< invalid_input.first << " - "
<< node_map[invalid_input.first]->DebugString();
}
return errors::Internal(
"Invalid graph with inputs referring to nonexistent nodes");
}
return Status::OK();
}
int CountParameters(const TransformFuncContext& context, const string& name) {
if (context.params.count(name)) {
return context.params.at(name).size();
} else {
return 0;
}
}
Status GetExactlyOneParameter(const TransformFuncContext& context,
const string& name, const string& default_value,
string* result) {
const int params_count = CountParameters(context, name);
if (params_count == 0) {
*result = default_value;
return Status::OK();
} else if (params_count == 1) {
*result = context.params.at(name).at(0);
return Status::OK();
} else {
return errors::InvalidArgument("Expected a single '", name,
"' parameter, but found ", params_count,
" occurrences");
}
}
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,7 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ #ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
#include <set>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
namespace tensorflow { namespace tensorflow {
@ -26,18 +30,73 @@ namespace graph_transforms {
void MapNamesToNodes(const GraphDef& graph_def, void MapNamesToNodes(const GraphDef& graph_def,
std::map<string, const NodeDef*>* result); std::map<string, const NodeDef*>* result);
// For every node in the graph create a list of the nodes that use it as an
// input.
void MapNodesToOutputs(const GraphDef& graph_def,
std::map<string, std::vector<const NodeDef*>>* result);
// NodeDef input strings can contain other information besides the name of an // NodeDef input strings can contain other information besides the name of an
// input node. These include: // input node. These include:
// - Optional '^' prefix, indicating this is a control edge. // - Optional '^' prefix, indicating this is a control edge.
// - The required name of the input node. // - The required name of the input node.
// - Option ':<number>' suffix, showing which output of the node to use. // - Optional ':<number>' suffix, showing which output of the node to use.
// This function takes a raw string, and breaks it into those component parts. // This function takes a raw string, and breaks it into those component parts.
void NodeNamePartsFromInput(string input_name, string* prefix, // The rules for inputs in function libraries are a bit more complex, and
// aren't handled by this routine.
void NodeNamePartsFromInput(const string& input_name, string* prefix,
string* node_name, string* suffix); string* node_name, string* suffix);
// Adds a ':0' port to any inputs with no suffix, to make comparisons easier.
string CanonicalInputName(const string& input_name);
// Convenience function to strip the optional prefix and suffix components from // Convenience function to strip the optional prefix and suffix components from
// a string pulled from a NodeDef input, and return the plain node name. // a string pulled from a NodeDef input, and return the plain node name.
string NodeNameFromInput(string input_name); string NodeNameFromInput(const string& input_name);
// Returns a stable hash for the contents of the NodeDef, so that equivalent
// nodes should have equal hashes.
uint64 HashNodeDef(const NodeDef& node);
// Adds the given node name to the end of the node's inputs.
void AddNodeInput(const string& input_name, NodeDef* node);
// Copies an attribute from one NodeDef to another.
void CopyNodeAttr(const NodeDef& source, const string& source_key,
const string& dest_key, NodeDef* dest);
// Inserts a value into a NodeDef's map of attributes.
// This is a bit different than AddNodeAttr in node_def_util.h because it
// overwrites any existing attributes with the same key.
template <class T>
inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
auto* attr_map = node->mutable_attr();
(*attr_map)[key] = attr_value;
}
template <class T>
inline void SetNodeTensorAttr(const string& key, const Tensor& tensor,
NodeDef* node) {
TensorProto tensor_proto;
tensor.AsProtoTensorContent(&tensor_proto);
SetNodeAttr(key, tensor_proto, node);
}
// Inserts a Tensor into the specified attribute of a NodeDef.
template <class T>
inline void SetNodeTensorAttr(const string& key, const TensorShape& shape,
const std::vector<T>& values, NodeDef* node) {
const DataType dtype = DataTypeToEnum<T>::v();
CHECK_EQ(shape.num_elements(), values.size());
Tensor tensor(dtype, shape);
T* dest_data = tensor.flat<T>().data();
std::copy_n(values.data(), values.size(), dest_data);
SetNodeTensorAttr<T>(key, tensor, node);
}
// Retrieves a tensor value from a NodeDef attribute.
Tensor GetNodeTensorAttr(const NodeDef& node, const string& key);
// Creates a copy of the input GraphDef, but only containing the nodes where the // Creates a copy of the input GraphDef, but only containing the nodes where the
// supplied selector function returned true. // supplied selector function returned true.
@ -51,6 +110,144 @@ void RemoveAttributes(const GraphDef& input_graph_def,
const std::vector<string>& attributes, const std::vector<string>& attributes,
GraphDef* output_graph_def); GraphDef* output_graph_def);
// For a lot of replacement and matching operations it's useful to have the
// nodes processed in a controlled order, so this does a topological sort to
// ensure that nodes always appear in the GraphDef.node list after their inputs.
Status SortByExecutionOrder(const GraphDef& input_graph_def,
GraphDef* output_graph_def);
// Finds inputs that refer to nodes that are not in the graph.
void FindInvalidInputs(const GraphDef& graph_def,
std::vector<std::pair<string, string>>* invalid_inputs);
// Returns a descriptive error status if there are problems spotted with the
// graph.
Status IsGraphValid(const GraphDef& graph_def);
// This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
// This defines a subgraph where a Conv2D has a ResizeBilinear input, which
// pulls from a MirrorPad op.
// Regular expressions aren't supported for the op names, but you can use "*" to
// match any op. You can also use | as a separator to match multiple op names,
// like "Reshape|Concat|Conv2D".
struct OpTypePattern {
string op;
std::vector<OpTypePattern> inputs;
string DebugString() const;
};
// Returns a sub-graph of nodes that match a pattern.
struct NodeMatch {
NodeMatch() : node() {}
NodeDef node;
std::vector<NodeMatch> inputs;
string DebugString() const;
};
// Utility class to spot subgraphs matching particular patterns.
class GraphMatcher {
public:
GraphMatcher(const GraphDef& graph_def);
// Sorts the input nodes into execution order, and then skips any previously
// matches so that no node appears in more than one match. The NodeDef
// pointers contained in the results are owned by the GraphMatcher object, and
// so will be invalid after its lifetime.
Status GetOpTypeMatches(const OpTypePattern& pattern,
std::vector<NodeMatch>* matches);
private:
bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern,
const std::set<string>& previously_matched_nodes,
NodeMatch* match);
GraphDef graph_def_;
std::map<string, const NodeDef*> node_map_;
};
struct ReplaceMatchingOpTypesOptions {
// Whether to raise an error if the graph is left with dangling inputs. If you
// enable this option, you must fix inconsistencies in a later pass.
bool allow_inconsistencies;
};
// Replaces all of the matching sub-graphs with new ops. This calls into the
// given function, and expects to receive a set of new nodes to replace each
// matched sub-graph. It has some logic to protect the integrity of the
// resulting graph, for example making sure that nodes needed by other nodes
// outside the sub-graph aren't removed. These are passed in as the set of
// outputs, and nodes with the same names must be added to the new nodes
// produced by the replacement function. Many of these checks can be disabled
// by setting allow_inconsistencies to true in the options, but then it's the
// caller's responsibility to patch up any problems before passing on the graph
// to others. There's more comprehensive usage documentation in the README.
Status ReplaceMatchingOpTypes(
const GraphDef& input_graph_def, const OpTypePattern& pattern,
const std::function<Status(const NodeMatch&, const std::set<string>&,
const std::set<string>&, std::vector<NodeDef>*)>&
node_generator,
const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def);
// Returns a list of the unique nodes found in this match.
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
// Changes all input references to a particular node name.
Status RenameNodeInputs(const GraphDef& input_graph_def,
const std::map<string, string>& inputs_to_rename,
GraphDef* output_graph_def);
// Utility function that copies all the nodes found in a match into the
// new_nodes list. This is useful in replacement functions when you decide to
// leave the original matched subgraph untouched and make no changes.
void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes);
// Holds information that's needed for transform functions.
typedef std::map<string, std::vector<string>> TransformFuncParameters;
struct TransformFuncContext {
std::vector<string> input_names;
std::vector<string> output_names;
TransformFuncParameters params;
};
// Returns how many occurrences of the given parameter are present.
int CountParameters(const TransformFuncContext& context, const string& name);
// Gets a simple occurrence of a parameter, using a default if it isn't present.
Status GetExactlyOneParameter(const TransformFuncContext& context,
const string& name, const string& default_value,
string* result);
// This is the function API for all graph transformations, taking an input
// GraphDef and other arguments, and returning a transformed GraphDef.
typedef std::function<Status(const GraphDef&,
const TransformFuncContext& context, GraphDef*)>
TransformFunc;
// To add a new graph transform function, call the macro:
// REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
// Under the hood this adds the function to the list of known transforms, so you
// just need to link in the .cc file with your registration call to have access
// to it through the command line tool.
// The rest of the machinery below is to enable that automagical registration.
typedef std::map<string, TransformFunc> TransformRegistry;
TransformRegistry* GetTransformRegistry();
class TransformRegistrar {
public:
TransformRegistrar(const string& name, TransformFunc transform_func) {
TransformRegistry* transform_registry = GetTransformRegistry();
(*transform_registry)[name] = transform_func;
}
};
#define REGISTER_GRAPH_TRANSFORM(name, func) \
REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func)
#define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \
REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)
#define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func) \
static tensorflow::graph_transforms::TransformRegistrar \
registrar__body__##ctr##__object(name, func);
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow

View File

@ -52,8 +52,8 @@ class TransformUtilsTest : public ::testing::Test {
GraphDef graph_def; GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def)); TF_ASSERT_OK(root.ToGraphDef(&graph_def));
std::map<string, const NodeDef*> node_map; std::map<string, const NodeDef*> node_map;
MapNamesToNodes(graph_def, &node_map); MapNamesToNodes(graph_def, &node_map);
EXPECT_EQ(1, node_map.count("a")); EXPECT_EQ(1, node_map.count("a"));
EXPECT_EQ(1, node_map.count("b")); EXPECT_EQ(1, node_map.count("b"));
EXPECT_EQ(1, node_map.count("add")); EXPECT_EQ(1, node_map.count("add"));
@ -62,6 +62,52 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("no_such_node")); EXPECT_EQ(0, node_map.count("no_such_node"));
} }
void TestMapNodesToOutputs() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 100;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Output add = Add(root.WithOpName("add"), a_const, b_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
std::map<string, std::vector<const NodeDef*>> outputs_map;
MapNodesToOutputs(graph_def, &outputs_map);
EXPECT_EQ(1, outputs_map.count("a"));
EXPECT_EQ(1, outputs_map["a"].size());
EXPECT_EQ("add", outputs_map["a"][0]->name());
EXPECT_EQ(1, outputs_map.count("b"));
EXPECT_EQ(1, outputs_map["b"].size());
EXPECT_EQ("add", outputs_map["b"][0]->name());
EXPECT_EQ(1, outputs_map.count("add"));
EXPECT_EQ(1, outputs_map["add"].size());
EXPECT_EQ("output", outputs_map["add"][0]->name());
EXPECT_EQ(1, outputs_map.count("placeholder"));
EXPECT_EQ(1, outputs_map["placeholder"].size());
EXPECT_EQ("output", outputs_map["placeholder"][0]->name());
EXPECT_EQ(0, outputs_map.count("output"));
EXPECT_EQ(0, outputs_map.count("no_such_node"));
}
void TestNodeNamePartsFromInput() { void TestNodeNamePartsFromInput() {
string prefix; string prefix;
string node_name; string node_name;
@ -101,6 +147,75 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42")); EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42"));
} }
void TestCanonicalInputName() {
EXPECT_EQ("node_name:0", CanonicalInputName("node_name"));
EXPECT_EQ("node_name:0", CanonicalInputName("node_name:0"));
EXPECT_EQ("^node_name:0", CanonicalInputName("^node_name"));
EXPECT_EQ("^node_name:42", CanonicalInputName("^node_name:42"));
}
void TestAddNodeInput() {
NodeDef node;
AddNodeInput("foo", &node);
EXPECT_EQ("foo", node.input(0));
}
void TestCopyNodeAttr() {
NodeDef node;
auto mutable_attr = node.mutable_attr();
(*mutable_attr)["foo"].set_i(3);
NodeDef copied_node;
CopyNodeAttr(node, "foo", "bar", &copied_node);
EXPECT_EQ(3, copied_node.attr().at("bar").i());
}
void TestSetNodeAttr() {
NodeDef node;
int32 value_i = 32;
SetNodeAttr("foo", value_i, &node);
EXPECT_EQ(32, node.attr().at("foo").i());
string value_s = "some_value";
SetNodeAttr("bar", value_s, &node);
EXPECT_EQ("some_value", node.attr().at("bar").s());
}
void TestSetNodeTensorAttr() {
NodeDef node;
SetNodeTensorAttr<int32>("foo", {3, 1}, {1, 2, 3}, &node);
TensorProto tensor_proto = node.attr().at("foo").tensor();
Tensor tensor;
CHECK(tensor.FromProto(tensor_proto));
EXPECT_EQ(DT_INT32, tensor.dtype());
EXPECT_EQ(3, tensor.shape().dim_size(0));
EXPECT_EQ(1, tensor.shape().dim_size(1));
EXPECT_EQ(1, tensor.flat<int32>()(0));
EXPECT_EQ(2, tensor.flat<int32>()(1));
EXPECT_EQ(3, tensor.flat<int32>()(2));
}
void TestSetNodeTensorAttrWithTensor() {
NodeDef node;
Tensor input_tensor(DT_INT32, {4, 5});
test::FillIota<int32>(&input_tensor, 1);
SetNodeTensorAttr<int32>("foo", input_tensor, &node);
TensorProto tensor_proto = node.attr().at("foo").tensor();
Tensor tensor;
CHECK(tensor.FromProto(tensor_proto));
test::ExpectTensorEqual<int32>(input_tensor, tensor);
}
void TestGetNodeTensorAttr() {
NodeDef node;
Tensor input_tensor(DT_INT32, {4, 5});
test::FillIota<int32>(&input_tensor, 1);
TensorProto tensor_proto;
input_tensor.AsProtoTensorContent(&tensor_proto);
SetNodeAttr("foo", tensor_proto, &node);
Tensor result = GetNodeTensorAttr(node, "foo");
test::ExpectTensorEqual<int32>(input_tensor, result);
}
void TestFilterGraphDef() { void TestFilterGraphDef() {
auto root = tensorflow::Scope::NewRootScope(); auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces) using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@ -160,19 +275,679 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ(nullptr, EXPECT_EQ(nullptr,
tensorflow::AttrSlice(*removed_placeholder).Find("dtype")); tensorflow::AttrSlice(*removed_placeholder).Find("dtype"));
} }
void TestGetOpTypeMatches() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 100;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Output add = Add(root.WithOpName("add"), a_const, b_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphMatcher matcher(graph_def);
std::vector<NodeMatch> const_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Const"}, &const_matches));
EXPECT_EQ(2, const_matches.size());
for (const NodeMatch& match : const_matches) {
EXPECT_EQ("Const", match.node.op());
EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
<< "match.node.name()=" << match.node.name();
}
std::vector<NodeMatch> add_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add"}, &add_matches));
EXPECT_EQ(1, add_matches.size());
EXPECT_EQ("Add", add_matches[0].node.op());
EXPECT_EQ("add", add_matches[0].node.name());
std::vector<NodeMatch> add_child_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
&add_child_matches));
EXPECT_EQ(1, add_child_matches.size());
EXPECT_EQ("Add", add_child_matches[0].node.op());
EXPECT_EQ("add", add_child_matches[0].node.name());
EXPECT_EQ(2, add_child_matches[0].inputs.size());
for (const NodeMatch& match : add_child_matches[0].inputs) {
EXPECT_EQ("Const", match.node.op());
EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
<< "match.node.name()=" << match.node.name();
}
std::vector<NodeMatch> no_such_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches({"NoSuch"}, &no_such_matches));
EXPECT_EQ(0, no_such_matches.size());
std::vector<NodeMatch> all_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches(
{"Mul", {{"Add", {{"Const"}, {"Const"}}}, {"Placeholder"}}},
&all_matches));
EXPECT_EQ(1, all_matches.size());
EXPECT_EQ("Mul", all_matches[0].node.op());
EXPECT_EQ("output", all_matches[0].node.name());
EXPECT_EQ(2, all_matches[0].inputs.size());
EXPECT_EQ("Add", all_matches[0].inputs[0].node.op());
EXPECT_EQ("add", all_matches[0].inputs[0].node.name());
EXPECT_EQ(2, all_matches[0].inputs[0].inputs.size());
EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[0].node.op());
EXPECT_EQ("a", all_matches[0].inputs[0].inputs[0].node.name());
EXPECT_EQ(0, all_matches[0].inputs[0].inputs[0].inputs.size());
EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[1].node.op());
EXPECT_EQ("b", all_matches[0].inputs[0].inputs[1].node.name());
EXPECT_EQ(0, all_matches[0].inputs[0].inputs[1].inputs.size());
EXPECT_EQ("Placeholder", all_matches[0].inputs[1].node.op());
EXPECT_EQ("placeholder", all_matches[0].inputs[1].node.name());
EXPECT_EQ(0, all_matches[0].inputs[1].inputs.size());
std::vector<NodeMatch> wildcard_matches;
TF_ASSERT_OK(
matcher.GetOpTypeMatches({"*", {{"*"}, {"*"}}}, &wildcard_matches));
EXPECT_EQ(1, wildcard_matches.size());
EXPECT_EQ("Add", wildcard_matches[0].node.op());
EXPECT_EQ("Const", wildcard_matches[0].inputs[0].node.op());
EXPECT_EQ("a", wildcard_matches[0].inputs[0].node.name());
EXPECT_EQ("Const", wildcard_matches[0].inputs[1].node.op());
EXPECT_EQ("b", wildcard_matches[0].inputs[1].node.name());
std::vector<NodeMatch> or_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add|Mul"}, &or_matches));
EXPECT_EQ(2, or_matches.size());
EXPECT_EQ("Add", or_matches[0].node.op());
EXPECT_EQ("add", or_matches[0].node.name());
EXPECT_EQ("Mul", or_matches[1].node.op());
EXPECT_EQ("output", or_matches[1].node.name());
}
void TestGetOpTypeMatchesDAG() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 100;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Output add = Add(root.WithOpName("add"), a_const, a_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphMatcher matcher(graph_def);
std::vector<NodeMatch> add_matches;
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
&add_matches));
EXPECT_EQ(1, add_matches.size());
EXPECT_EQ("Add", add_matches[0].node.op());
EXPECT_EQ("add", add_matches[0].node.name());
EXPECT_EQ("Const", add_matches[0].inputs[0].node.op());
EXPECT_EQ("a", add_matches[0].inputs[0].node.name());
EXPECT_EQ("Const", add_matches[0].inputs[1].node.op());
EXPECT_EQ("a", add_matches[0].inputs[1].node.name());
}
void TestReplaceMatchingOpTypes() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Output add = Add(root.WithOpName("add"), a_const, b_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef replaced_graph_def;
TF_ASSERT_OK(ReplaceMatchingOpTypes(
graph_def, {"*"},
[](const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
NodeDef original_copy;
original_copy.CopyFrom(match.node);
const string original_name = match.node.name();
original_copy.set_name(original_name + "_before_identity");
new_nodes->push_back(original_copy);
NodeDef identity_node;
identity_node.set_op("Identity");
identity_node.set_name(original_name);
*(identity_node.mutable_input()->Add()) = original_copy.name();
new_nodes->push_back(identity_node);
return Status::OK();
},
{}, &replaced_graph_def));
EXPECT_EQ(10, replaced_graph_def.node_size());
for (const NodeDef& node : replaced_graph_def.node()) {
if (node.name() == "output") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("output_before_identity", node.input(0));
} else if (node.name() == "output_before_identity") {
EXPECT_EQ("Mul", node.op());
EXPECT_EQ("add", node.input(0));
EXPECT_EQ("placeholder", node.input(1));
} else if (node.name() == "placeholder") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("placeholder_before_identity", node.input(0));
} else if (node.name() == "placeholder_before_identity") {
EXPECT_EQ("Placeholder", node.op());
} else if (node.name() == "add") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("add_before_identity", node.input(0));
} else if (node.name() == "add_before_identity") {
EXPECT_EQ("Add", node.op());
EXPECT_EQ("a", node.input(0));
EXPECT_EQ("b", node.input(1));
} else if (node.name() == "a") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("a_before_identity", node.input(0));
} else if (node.name() == "a_before_identity") {
EXPECT_EQ("Const", node.op());
} else if (node.name() == "b") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("b_before_identity", node.input(0));
} else if (node.name() == "b_before_identity") {
EXPECT_EQ("Const", node.op());
} else {
EXPECT_EQ(true, false) << "Unexpected node name found: " << node.name();
}
}
}
void TestMatchedNodesAsArray() {
NodeMatch fourth;
fourth.node.set_name("fourth");
NodeMatch second;
second.node.set_name("second");
second.inputs.push_back(fourth);
NodeMatch third;
third.node.set_name("third");
third.inputs.push_back(fourth);
NodeMatch first;
first.node.set_name("first");
first.inputs.push_back(second);
first.inputs.push_back(third);
std::vector<NodeDef> result;
MatchedNodesAsArray(first, &result);
EXPECT_EQ(4, result.size());
EXPECT_EQ("first", result[0].name());
EXPECT_EQ("second", result[1].name());
EXPECT_EQ("third", result[2].name());
EXPECT_EQ("fourth", result[3].name());
}
void TestRenameNodeInputs() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Output add = Add(root.WithOpName("add"), a_const, a_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
EXPECT_EQ("b", node_map.at("add")->input(0));
EXPECT_EQ("b", node_map.at("add")->input(1));
}
void TestRenameNodeInputsWithRedirects() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Tensor c_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&c_data, 1.0f);
Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
Output add = Add(root.WithOpName("add"), a_const, b_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(
graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
&renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
EXPECT_EQ("c", node_map.at("add")->input(0));
EXPECT_EQ("b", node_map.at("add")->input(1));
}
void TestRenameNodeInputsWithCycle() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
Tensor c_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&c_data, 1.0f);
Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
Output add = Add(root.WithOpName("add"), a_const, b_const);
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
Output mul = Mul(root.WithOpName("output"), add, placeholder);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
&renamed_graph_def);
EXPECT_FALSE(rename_status.ok());
}
void TestRenameNodeInputsWithWildcard() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
QuantizeV2 quantize_a(root.WithOpName("quantize_a"), a_const, a_const,
a_const, DT_QUINT8,
QuantizeV2::Attrs().Mode("MIN_FIRST"));
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
QuantizeV2 quantize_b(root.WithOpName("quantize_b"), b_const, b_const,
b_const, DT_QUINT8,
QuantizeV2::Attrs().Mode("MIN_FIRST"));
Output add = Add(root.WithOpName("add"), quantize_a.output_min,
quantize_a.output_max);
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef renamed_graph_def;
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
&renamed_graph_def));
std::map<string, const NodeDef*> node_map;
MapNamesToNodes(renamed_graph_def, &node_map);
EXPECT_EQ("quantize_b:1", node_map.at("add")->input(0));
EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
}
void TestFindInvalidInputs() {
GraphDef graph_def;
NodeDef* mul_node = graph_def.mutable_node()->Add();
mul_node->set_op("Mul");
mul_node->set_name("mul_node");
*(mul_node->mutable_input()->Add()) = "add_node1";
*(mul_node->mutable_input()->Add()) = "add_node2:0";
*(mul_node->mutable_input()->Add()) = "^const_node1:0";
NodeDef* add_node1 = graph_def.mutable_node()->Add();
add_node1->set_op("Add");
add_node1->set_name("add_node1");
*(add_node1->mutable_input()->Add()) = "missing_input1";
*(add_node1->mutable_input()->Add()) = "const_node1:0";
*(add_node1->mutable_input()->Add()) = "missing_input2";
NodeDef* add_node2 = graph_def.mutable_node()->Add();
add_node2->set_op("Add");
add_node2->set_name("add_node2");
*(add_node2->mutable_input()->Add()) = "missing_input3";
*(add_node2->mutable_input()->Add()) = "const_node1:0";
*(add_node2->mutable_input()->Add()) = "^const_node2";
NodeDef* const_node1 = graph_def.mutable_node()->Add();
const_node1->set_op("Const");
const_node1->set_name("const_node1");
NodeDef* const_node2 = graph_def.mutable_node()->Add();
const_node2->set_op("Const");
const_node2->set_name("const_node2");
std::vector<std::pair<string, string>> invalid_inputs;
FindInvalidInputs(graph_def, &invalid_inputs);
EXPECT_EQ(3, invalid_inputs.size());
for (const std::pair<string, string>& invalid_input : invalid_inputs) {
EXPECT_TRUE((invalid_input.first == "add_node1") ||
(invalid_input.first == "add_node2"));
if (invalid_input.first == "add_node1") {
EXPECT_TRUE((invalid_input.second == "missing_input1") ||
(invalid_input.second == "missing_input2"))
<< invalid_input.second;
} else if (invalid_input.first == "add_node2") {
EXPECT_EQ("missing_input3", invalid_input.second);
}
}
}
void TestIsGraphValid() {
GraphDef invalid_graph_def;
NodeDef* mul_node = invalid_graph_def.mutable_node()->Add();
mul_node->set_op("Mul");
mul_node->set_name("mul_node");
*(mul_node->mutable_input()->Add()) = "add_node1";
*(mul_node->mutable_input()->Add()) = "add_node2:0";
*(mul_node->mutable_input()->Add()) = "^const_node1:0";
NodeDef* add_node1 = invalid_graph_def.mutable_node()->Add();
add_node1->set_op("Add");
add_node1->set_name("add_node1");
*(add_node1->mutable_input()->Add()) = "missing_input1";
*(add_node1->mutable_input()->Add()) = "const_node1:0";
*(add_node1->mutable_input()->Add()) = "missing_input2";
NodeDef* add_node2 = invalid_graph_def.mutable_node()->Add();
add_node2->set_op("Add");
add_node2->set_name("add_node2");
*(add_node2->mutable_input()->Add()) = "missing_input3";
*(add_node2->mutable_input()->Add()) = "const_node1:0";
*(add_node2->mutable_input()->Add()) = "^const_node2";
NodeDef* const_node1 = invalid_graph_def.mutable_node()->Add();
const_node1->set_op("Const");
const_node1->set_name("const_node1");
NodeDef* const_node2 = invalid_graph_def.mutable_node()->Add();
const_node2->set_op("Const");
const_node2->set_name("const_node2");
EXPECT_FALSE(IsGraphValid(invalid_graph_def).ok());
GraphDef valid_graph_def;
NodeDef* const_node3 = valid_graph_def.mutable_node()->Add();
const_node3->set_op("Const");
const_node3->set_name("const_node2");
EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
}
void TestCopyOriginalMatch() {
NodeDef a;
a.set_op("Relu");
a.set_name("a");
AddNodeInput("b", &a);
NodeDef b;
b.set_op("Const");
b.set_name("b");
NodeMatch b_match;
b_match.node = b;
NodeMatch a_match;
a_match.node = a;
a_match.inputs.push_back(b_match);
std::vector<NodeDef> new_nodes;
CopyOriginalMatch(a_match, &new_nodes);
EXPECT_EQ(2, new_nodes.size());
EXPECT_EQ("a", new_nodes[0].name());
EXPECT_EQ("Relu", new_nodes[0].op());
EXPECT_EQ("b", new_nodes[1].name());
EXPECT_EQ("Const", new_nodes[1].op());
}
void TestHashNodeDef() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
auto a_root = tensorflow::Scope::NewRootScope();
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(a_root.WithOpName("a"), Input::Initializer(a_data));
GraphDef a_graph_def;
TF_ASSERT_OK(a_root.ToGraphDef(&a_graph_def));
const NodeDef& a_node_def = a_graph_def.node(0);
auto b_root = tensorflow::Scope::NewRootScope();
Tensor b_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&b_data, 1.0f);
Output b_const = Const(b_root.WithOpName("a"), Input::Initializer(b_data));
GraphDef b_graph_def;
TF_ASSERT_OK(b_root.ToGraphDef(&b_graph_def));
const NodeDef& b_node_def = b_graph_def.node(0);
auto c_root = tensorflow::Scope::NewRootScope();
Tensor c_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&c_data, 2.0f);
Output c_const = Const(c_root.WithOpName("a"), Input::Initializer(c_data));
GraphDef c_graph_def;
TF_ASSERT_OK(c_root.ToGraphDef(&c_graph_def));
const NodeDef& c_node_def = c_graph_def.node(0);
auto d_root = tensorflow::Scope::NewRootScope();
Tensor d_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&d_data, 1.0f);
Output d_const = Const(d_root.WithOpName("d"), Input::Initializer(d_data));
GraphDef d_graph_def;
TF_ASSERT_OK(d_root.ToGraphDef(&d_graph_def));
const NodeDef& d_node_def = d_graph_def.node(0);
auto e_root = tensorflow::Scope::NewRootScope();
Tensor e_data(DT_INT32, TensorShape({width}));
test::FillIota<int32>(&e_data, 1);
Output e_const = Const(e_root.WithOpName("a"), Input::Initializer(e_data));
GraphDef e_graph_def;
TF_ASSERT_OK(e_root.ToGraphDef(&e_graph_def));
const NodeDef& e_node_def = e_graph_def.node(0);
auto f_root = tensorflow::Scope::NewRootScope();
Tensor f_data(DT_FLOAT, TensorShape({width - 1}));
test::FillIota<float>(&f_data, 1.0f);
Output f_const = Const(f_root.WithOpName("a"), Input::Initializer(f_data));
GraphDef f_graph_def;
TF_ASSERT_OK(f_root.ToGraphDef(&f_graph_def));
const NodeDef& f_node_def = f_graph_def.node(0);
auto g_root = tensorflow::Scope::NewRootScope();
Tensor g_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&g_data, 1);
Output g_const = Const(g_root.WithOpName("a").WithDevice("some_device"),
Input::Initializer(g_data));
GraphDef g_graph_def;
TF_ASSERT_OK(g_root.ToGraphDef(&g_graph_def));
const NodeDef& g_node_def = g_graph_def.node(0);
NodeDef relu1_node_def;
relu1_node_def.set_op("Relu");
relu1_node_def.set_name("a");
relu1_node_def.add_input("foo");
NodeDef relu2_node_def;
relu2_node_def.set_op("Relu");
relu2_node_def.set_name("a");
relu2_node_def.add_input("bar");
EXPECT_EQ(HashNodeDef(a_node_def), HashNodeDef(b_node_def));
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(c_node_def));
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(d_node_def));
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(e_node_def));
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(f_node_def));
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(g_node_def));
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(relu1_node_def));
EXPECT_NE(HashNodeDef(relu1_node_def), HashNodeDef(relu2_node_def));
}
void TestCountParameters() {
TransformFuncContext context;
context.params.insert({"foo", {"a", "b"}});
context.params.insert({"bar", {"c"}});
EXPECT_EQ(2, CountParameters(context, "foo"));
EXPECT_EQ(1, CountParameters(context, "bar"));
EXPECT_EQ(0, CountParameters(context, "not_present"));
}
void TestGetExactlyOneParameter() {
TransformFuncContext context;
context.params.insert({"foo", {"a", "b"}});
context.params.insert({"bar", {"c"}});
string value;
TF_EXPECT_OK(GetExactlyOneParameter(context, "bar", "d", &value));
EXPECT_EQ("c", value);
EXPECT_FALSE(GetExactlyOneParameter(context, "foo", "d", &value).ok());
TF_EXPECT_OK(GetExactlyOneParameter(context, "not_present", "d", &value));
EXPECT_EQ("d", value);
}
}; };
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
TEST_F(TransformUtilsTest, TestMapNodesToOutputs) { TestMapNodesToOutputs(); }
TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) { TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) {
TestNodeNamePartsFromInput(); TestNodeNamePartsFromInput();
} }
TEST_F(TransformUtilsTest, TestCanonicalInputName) { TestCanonicalInputName(); }
TEST_F(TransformUtilsTest, TestAddNodeInput) { TestAddNodeInput(); }
TEST_F(TransformUtilsTest, TestCopyNodeAttr) { TestCopyNodeAttr(); }
TEST_F(TransformUtilsTest, TestSetNodeAttr) { TestSetNodeAttr(); }
TEST_F(TransformUtilsTest, TestSetNodeTensorAttr) { TestSetNodeTensorAttr(); }
TEST_F(TransformUtilsTest, TestSetNodeTensorAttrWithTensor) {
TestSetNodeTensorAttrWithTensor();
}
TEST_F(TransformUtilsTest, TestGetNodeTensorAttr) { TestGetNodeTensorAttr(); }
TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); } TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); }
TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); } TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); }
TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); } TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); }
TEST_F(TransformUtilsTest, TestGetOpTypeMatches) { TestGetOpTypeMatches(); }
TEST_F(TransformUtilsTest, TestGetOpTypeMatchesDAG) {
TestGetOpTypeMatchesDAG();
}
TEST_F(TransformUtilsTest, TestReplaceMatchingOpTypes) {
TestReplaceMatchingOpTypes();
}
TEST_F(TransformUtilsTest, TestMatchedNodesAsArray) {
TestMatchedNodesAsArray();
}
TEST_F(TransformUtilsTest, TestRenameNodeInputs) { TestRenameNodeInputs(); }
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithRedirects) {
TestRenameNodeInputsWithRedirects();
}
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithCycle) {
TestRenameNodeInputsWithCycle();
}
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
TestRenameNodeInputsWithWildcard();
}
TEST_F(TransformUtilsTest, TestFindInvalidInputs) { TestFindInvalidInputs(); }
TEST_F(TransformUtilsTest, TestIsGraphValid) { TestIsGraphValid(); }
TEST_F(TransformUtilsTest, TestCopyOriginalMatch) { TestCopyOriginalMatch(); }
TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }
TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) {
TestGetExactlyOneParameter();
}
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow