Create Graph Transform Tool for rewriting model files.
Change: 142729497
This commit is contained in:
parent
be60473c88
commit
0f0e29e7ba
@ -252,6 +252,16 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
|
|||||||
bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
||||||
FunctionLibraryRuntime* function_library, Env* env,
|
FunctionLibraryRuntime* function_library, Env* env,
|
||||||
Device* partition_device, Graph* graph) {
|
Device* partition_device, Graph* graph) {
|
||||||
|
bool was_mutated;
|
||||||
|
Status unused_status = DoConstantFoldingWithStatus(
|
||||||
|
opts, function_library, env, partition_device, graph, &was_mutated);
|
||||||
|
return was_mutated;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
|
||||||
|
FunctionLibraryRuntime* function_library,
|
||||||
|
Env* env, Device* partition_device,
|
||||||
|
Graph* graph, bool* was_mutated) {
|
||||||
DumpGraph("Before", graph);
|
DumpGraph("Before", graph);
|
||||||
|
|
||||||
const FunctionLibraryDefinition* flib_def = nullptr;
|
const FunctionLibraryDefinition* flib_def = nullptr;
|
||||||
@ -263,7 +273,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
|||||||
FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes);
|
FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes);
|
||||||
if (constant_foldable_nodes.empty()) {
|
if (constant_foldable_nodes.empty()) {
|
||||||
VLOG(1) << "No constant foldable nodes found";
|
VLOG(1) << "No constant foldable nodes found";
|
||||||
return false;
|
*was_mutated = false;
|
||||||
|
// This is not an error, so return the status as OK.
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<NodeAndOutput, Node*> tensors_to_fetch;
|
std::map<NodeAndOutput, Node*> tensors_to_fetch;
|
||||||
@ -273,7 +285,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
|||||||
|
|
||||||
if (tensors_to_fetch.empty()) {
|
if (tensors_to_fetch.empty()) {
|
||||||
VLOG(1) << "No constant nodes found that feed into the original graph.";
|
VLOG(1) << "No constant nodes found that feed into the original graph.";
|
||||||
return false;
|
*was_mutated = false;
|
||||||
|
// This is not an error, so return the status as OK.
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
|
VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
|
||||||
<< graph->num_node_ids();
|
<< graph->num_node_ids();
|
||||||
@ -292,7 +306,9 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
|||||||
{} /* inputs*/, tensors_to_fetch_names, &outputs);
|
{} /* inputs*/, tensors_to_fetch_names, &outputs);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
VLOG(1) << "Could not fetch constants: " << s;
|
VLOG(1) << "Could not fetch constants: " << s;
|
||||||
return false;
|
*was_mutated = false;
|
||||||
|
// Fetching the constants failed, so propagate the error status to the caller.
|
||||||
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the constant tensors and replace the corresponding tensors in the
|
// Fetch the constant tensors and replace the corresponding tensors in the
|
||||||
@ -307,7 +323,8 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
|||||||
|
|
||||||
DumpGraph("After", graph);
|
DumpGraph("After", graph);
|
||||||
|
|
||||||
return num_nodes_replaced > 0;
|
*was_mutated = (num_nodes_replaced > 0);
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -29,7 +29,16 @@ namespace tensorflow {
|
|||||||
// and replaces those nodes with the result of the evaluation.
|
// and replaces those nodes with the result of the evaluation.
|
||||||
// "partition_device", if non-null, is the device where all the graph nodes are
|
// "partition_device", if non-null, is the device where all the graph nodes are
|
||||||
// assumed to execute.
|
// assumed to execute.
|
||||||
// Returns true if and only if "graph" has been mutated.
|
// Sets `was_mutated` to true if and only if "graph" has been mutated.
|
||||||
|
// The status is only set to a non-OK state if an unexpected error is hit
|
||||||
|
// running the graph.
|
||||||
|
Status DoConstantFoldingWithStatus(const ConstantFoldingOptions& opts,
|
||||||
|
FunctionLibraryRuntime* function_library,
|
||||||
|
Env* env, Device* partition_device,
|
||||||
|
Graph* graph, bool* was_mutated);
|
||||||
|
|
||||||
|
// Version of the function that doesn't return a Status, for backwards
|
||||||
|
// compatibility.
|
||||||
bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
bool DoConstantFolding(const ConstantFoldingOptions& opts,
|
||||||
FunctionLibraryRuntime* function_library, Env* env,
|
FunctionLibraryRuntime* function_library, Env* env,
|
||||||
Device* partition_device, Graph* graph);
|
Device* partition_device, Graph* graph);
|
||||||
|
@ -228,8 +228,12 @@ TEST_F(ConstantFoldingTest, TestNoReplaceLargeConstant) {
|
|||||||
g->AddControlEdge(concat_send, g->sink_node());
|
g->AddControlEdge(concat_send, g->sink_node());
|
||||||
|
|
||||||
// The above concat should not have been constant folded.
|
// The above concat should not have been constant folded.
|
||||||
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
|
bool was_mutated;
|
||||||
Env::Default(), nullptr, g));
|
Status status =
|
||||||
|
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||||
|
Env::Default(), nullptr, g, &was_mutated);
|
||||||
|
EXPECT_FALSE(was_mutated);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
||||||
@ -257,8 +261,12 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) {
|
|||||||
g->AddControlEdge(times_two_send, g->sink_node());
|
g->AddControlEdge(times_two_send, g->sink_node());
|
||||||
|
|
||||||
// The above function call should not have been constant folded.
|
// The above function call should not have been constant folded.
|
||||||
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
|
bool was_mutated;
|
||||||
Env::Default(), nullptr, g));
|
status =
|
||||||
|
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||||
|
Env::Default(), nullptr, g, &was_mutated);
|
||||||
|
EXPECT_FALSE(was_mutated);
|
||||||
|
EXPECT_FALSE(status.ok());
|
||||||
|
|
||||||
g_ = nullptr;
|
g_ = nullptr;
|
||||||
}
|
}
|
||||||
@ -337,10 +345,16 @@ TEST_F(ConstantFoldingTest, TestImmutableConst) {
|
|||||||
auto result2 = ops::MatMul(root, result1, c);
|
auto result2 = ops::MatMul(root, result1, c);
|
||||||
TF_ASSERT_OK(root.ToGraph(g));
|
TF_ASSERT_OK(root.ToGraph(g));
|
||||||
TestTFEnvironment test_env;
|
TestTFEnvironment test_env;
|
||||||
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, nullptr,
|
bool was_mutated;
|
||||||
Env::Default(), nullptr, g));
|
Status status =
|
||||||
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, &test_env,
|
DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||||
nullptr, g));
|
Env::Default(), nullptr, g, &was_mutated);
|
||||||
|
EXPECT_FALSE(was_mutated);
|
||||||
|
EXPECT_FALSE(status.ok());
|
||||||
|
status = DoConstantFoldingWithStatus(ConstantFoldingOptions{}, nullptr,
|
||||||
|
&test_env, nullptr, g, &was_mutated);
|
||||||
|
EXPECT_TRUE(was_mutated);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -41,11 +41,12 @@ template <typename Device, typename T>
|
|||||||
class QuantizeV2Op : public OpKernel {
|
class QuantizeV2Op : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
half_range_ = !std::is_signed<T>::value
|
half_range_ =
|
||||||
? 0.0f
|
!std::is_signed<T>::value
|
||||||
: (std::numeric_limits<T>::max() -
|
? 0.0f
|
||||||
std::numeric_limits<T>::min() + 1) /
|
: (static_cast<double>(std::numeric_limits<T>::max()) -
|
||||||
2.0f;
|
static_cast<double>(std::numeric_limits<T>::min()) + 1) /
|
||||||
|
2.0f;
|
||||||
string mode_string;
|
string mode_string;
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
|
||||||
OP_REQUIRES(ctx,
|
OP_REQUIRES(ctx,
|
||||||
@ -90,7 +91,8 @@ class QuantizeV2Op : public OpKernel {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
|
||||||
if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
|
if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
|
||||||
const float scale_factor =
|
const float scale_factor =
|
||||||
(std::numeric_limits<T>::max() - std::numeric_limits<T>::min()) /
|
(static_cast<double>(std::numeric_limits<T>::max()) -
|
||||||
|
static_cast<double>(std::numeric_limits<T>::min())) /
|
||||||
(max_range - min_range);
|
(max_range - min_range);
|
||||||
|
|
||||||
// Quantize:
|
// Quantize:
|
||||||
@ -162,5 +164,8 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
|
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
|
||||||
QuantizeV2Op<CPUDevice, qint16>);
|
QuantizeV2Op<CPUDevice, qint16>);
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
|
||||||
|
QuantizeV2Op<CPUDevice, qint32>);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -47,6 +47,46 @@ TEST_F(QuantizedOpTest, QuantizeV2) {
|
|||||||
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
|
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(QuantizedOpTest, QuantizeV2_32Bit) {
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
|
||||||
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
.Attr("T", DataTypeToEnum<qint32>::v())
|
||||||
|
.Attr("mode", "MIN_FIRST")
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_ASSERT_OK(InitOp());
|
||||||
|
const int element_count = 8;
|
||||||
|
AddInputFromArray<float>(
|
||||||
|
TensorShape({element_count}),
|
||||||
|
{-500.0f, 0.0f, 1.0f, 1.25f, 1.75f, 127.0f, 255.0f, 500.0f});
|
||||||
|
AddInputFromArray<float>(TensorShape({1}), {-256.0f});
|
||||||
|
AddInputFromArray<float>(TensorShape({1}), {256.0f});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
Tensor expected(allocator(), DT_QINT32, TensorShape({element_count}));
|
||||||
|
test::FillValues<qint32>(&expected,
|
||||||
|
{
|
||||||
|
std::numeric_limits<int32>::min(), 0,
|
||||||
|
static_cast<int32>(1.0f * (1 << 23)),
|
||||||
|
static_cast<int32>(1.25f * (1 << 23)),
|
||||||
|
static_cast<int32>(1.75f * (1 << 23)),
|
||||||
|
static_cast<int32>(127.0f * (1 << 23)),
|
||||||
|
static_cast<int32>(255.0f * (1 << 23)),
|
||||||
|
std::numeric_limits<int32>::max(),
|
||||||
|
});
|
||||||
|
// We expect there will be some fuzziness in the lower bits, since this is
|
||||||
|
// converting from float.
|
||||||
|
const int64 epsilon = 1 << 8;
|
||||||
|
const qint32* output_data = GetOutput(0)->flat<qint32>().data();
|
||||||
|
const qint32* expected_data = expected.flat<qint32>().data();
|
||||||
|
for (int i = 0; i < element_count; ++i) {
|
||||||
|
const int64 delta = output_data[i] - expected_data[i];
|
||||||
|
EXPECT_GT(epsilon, std::abs(delta))
|
||||||
|
<< "output_data[" << i << "]=" << output_data[i] << ", expected_data["
|
||||||
|
<< i << "]=" << expected_data[i] << ", delta=" << delta;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(QuantizedOpTest, QuantizeV2Ports) {
|
TEST_F(QuantizedOpTest, QuantizeV2Ports) {
|
||||||
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
|
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
|
||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(DT_FLOAT))
|
||||||
|
@ -29,6 +29,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
],
|
],
|
||||||
@ -52,9 +53,23 @@ tf_cc_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "fold_constants_lib",
|
name = "transforms_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"fold_batch_norms.cc",
|
||||||
"fold_constants_lib.cc",
|
"fold_constants_lib.cc",
|
||||||
|
"fold_old_batch_norms.cc",
|
||||||
|
"fuse_convolutions.cc",
|
||||||
|
"obsfucate_names.cc",
|
||||||
|
"quantize_nodes.cc",
|
||||||
|
"quantize_weights.cc",
|
||||||
|
"remove_attribute.cc",
|
||||||
|
"remove_device.cc",
|
||||||
|
"remove_nodes.cc",
|
||||||
|
"rename_attribute.cc",
|
||||||
|
"rename_op.cc",
|
||||||
|
"round_weights.cc",
|
||||||
|
"sort_by_execution_order.cc",
|
||||||
|
"strip_unused_nodes.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"fold_constants_lib.h",
|
"fold_constants_lib.h",
|
||||||
@ -65,20 +80,98 @@ cc_library(
|
|||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
|
"//tensorflow/core/kernels:quantized_ops",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "transforms_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"fold_batch_norms_test.cc",
|
||||||
|
"fold_constants_test.cc",
|
||||||
|
"fold_old_batch_norms_test.cc",
|
||||||
|
"fuse_convolutions_test.cc",
|
||||||
|
"obsfucate_names_test.cc",
|
||||||
|
"quantize_nodes_test.cc",
|
||||||
|
"quantize_weights_test.cc",
|
||||||
|
"remove_attribute_test.cc",
|
||||||
|
"remove_device_test.cc",
|
||||||
|
"remove_nodes_test.cc",
|
||||||
|
"rename_attribute_test.cc",
|
||||||
|
"rename_op_test.cc",
|
||||||
|
"round_weights_test.cc",
|
||||||
|
"sort_by_execution_order_test.cc",
|
||||||
|
"strip_unused_nodes_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":transform_utils",
|
||||||
|
":transforms_lib",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/cc:sendrecv_ops",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels:quantized_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "transform_graph_lib",
|
||||||
|
srcs = ["transform_graph.cc"],
|
||||||
|
hdrs = ["transform_graph.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":transform_utils",
|
||||||
|
":transforms_lib",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# This library includes a main function, to make it easy to create other
|
||||||
|
# versions of the tool linked against different operator libs.
|
||||||
|
cc_library(
|
||||||
|
name = "transform_graph_main_lib",
|
||||||
|
srcs = ["transform_graph_main.cc"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":transform_graph_lib",
|
||||||
|
":transforms_lib",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "transform_graph",
|
||||||
|
copts = tf_copts(),
|
||||||
|
linkstatic = 1,
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":transform_graph_main_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "fold_constants_test",
|
name = "transform_graph_test",
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["fold_constants_test.cc"],
|
srcs = ["transform_graph_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":fold_constants_lib",
|
":transform_graph_lib",
|
||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/cc:sendrecv_ops",
|
"//tensorflow/cc:sendrecv_ops",
|
||||||
@ -95,23 +188,24 @@ tf_cc_test(
|
|||||||
# This library includes a main function, to make it easy to create other
|
# This library includes a main function, to make it easy to create other
|
||||||
# versions of the tool linked against different operator libs.
|
# versions of the tool linked against different operator libs.
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "fold_constants_main_lib",
|
name = "summarize_graph_main_lib",
|
||||||
srcs = ["fold_constants_tool.cc"],
|
srcs = ["summarize_graph_main.cc"],
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":fold_constants_lib",
|
":transform_utils",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "fold_constants_tool",
|
name = "summarize_graph",
|
||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":fold_constants_main_lib",
|
":summarize_graph_main_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
858
tensorflow/tools/graph_transforms/README.md
Normal file
858
tensorflow/tools/graph_transforms/README.md
Normal file
@ -0,0 +1,858 @@
|
|||||||
|
# Graph Transform Tool
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
* [Introduction](#introduction)
|
||||||
|
* [Using the Graph Transform Tool](#using-the-graph-transform-tool)
|
||||||
|
* [Inspecting Graphs](#inspecting-graphs)
|
||||||
|
* [Common Use Cases](#common-use-cases)
|
||||||
|
* [Optimizing for Deployment](#optimizing-for-deployment)
|
||||||
|
* [Fixing Missing Kernel Errors on
|
||||||
|
Mobile](#fixing-missing-kernel-errors-on-mobile)
|
||||||
|
* [Shrinking File Size](#shrinking-file-size)
|
||||||
|
* [Eight-bit Calculations](#eight-bit-calculations)
|
||||||
|
* [Transform Reference](#transform-reference)
|
||||||
|
* [fold_batch_norms](#fold_batch_norms)
|
||||||
|
* [fold_constants](#fold_constants)
|
||||||
|
* [fold_old_batch_norms](#fold_old_batch_norms)
|
||||||
|
* [fuse_convolutions](#fuse_convolutions)
|
||||||
|
* [merge_duplicate_nodes](#merge_duplicate_nodes)
|
||||||
|
* [obsfucate_names](#obsfucate_names)
|
||||||
|
* [quantize_nodes](#quantize_nodes)
|
||||||
|
* [quantize_weights](#quantize_weights)
|
||||||
|
* [remove_attribute](#remove_attribute)
|
||||||
|
* [remove_device](#remove_device)
|
||||||
|
* [remove_nodes](#remove_nodes)
|
||||||
|
* [rename_attribute](#rename_attribute)
|
||||||
|
* [rename_op](#rename_op)
|
||||||
|
* [round_weights](#round_weights)
|
||||||
|
* [sort_by_execution_order](#sort_by_execution_order)
|
||||||
|
* [strip_unused_nodes](#strip_unused_nodes)
|
||||||
|
* [Writing Your Own Transforms](#writing-your-own-transforms)
|
||||||
|
* [Transform Functions](#transform-functions)
|
||||||
|
* [Pattern Syntax](#pattern-syntax)
|
||||||
|
* [ReplaceMatchingOpTypes](#replacematchingoptypes)
|
||||||
|
* [Parameters](#parameters)
|
||||||
|
* [Function Libraries](#function-libraries)
|
||||||
|
* [Registering](#registering)
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
When you have finished training a model and want to deploy it in production,
|
||||||
|
you'll often want to modify it to better run in its final environment. For
|
||||||
|
example, if you're targeting a phone, you might want to shrink the file size by
|
||||||
|
quantizing the weights, or optimize away batch normalization or other
|
||||||
|
training-only features. The Graph Transform framework offers a suite of tools
|
||||||
|
for modifying computational graphs, and a framework to make it easy to write
|
||||||
|
your own modifications.
|
||||||
|
|
||||||
|
This guide is structured into three main parts, first giving some tutorials on
|
||||||
|
how to perform common tasks, second a reference covering all of the different
|
||||||
|
transformations that are included, together with the options that apply to them,
|
||||||
|
and third a guide to creating your own transforms.
|
||||||
|
|
||||||
|
## Using the Graph Transform Tool
|
||||||
|
|
||||||
|
The Graph Transform tool is designed to work on models that are saved as
|
||||||
|
GraphDef files, usually in a binary protobuf format. This is the low-level
|
||||||
|
definition of a TensorFlow computational graph, including a list of nodes and
|
||||||
|
the input and output connections between them. If you're using a Python API to
|
||||||
|
train your model, this will usually be saved out in the same directory as your
|
||||||
|
checkpoints, and usually has a '.pb' suffix.
|
||||||
|
|
||||||
|
If you want to work with the values of your trained parameters, for example to
|
||||||
|
quantize weights, you'll need to run
|
||||||
|
[tensorflow/python/tools/freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)
|
||||||
|
to convert the checkpoint values into embedded constants within the graph file
|
||||||
|
itself.
|
||||||
|
|
||||||
|
You call the Graph Transform tool itself like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
strip_unused_nodes(type=float, shape="1,299,299,3") \
|
||||||
|
remove_nodes(op=Identity, op=CheckNumerics) \
|
||||||
|
fold_old_batch_norms \
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
The arguments here are specifying where to read the graph from, where to write
|
||||||
|
the transformed version to, what the input and output layers are, and what
|
||||||
|
transforms to modify the graph with. The transforms are given as a list of
|
||||||
|
names, and can each have arguments themselves. These transforms define the
|
||||||
|
pipeline of modifications that are applied in order to produce the output.
|
||||||
|
Sometimes you need some transforms to happen before others, and the ordering
|
||||||
|
within the list lets you specify which happen first.
|
||||||
|
|
||||||
|
## Inspecting Graphs
|
||||||
|
|
||||||
|
Many of the transforms that the tool supports need to know what the input and
|
||||||
|
output layers of the model are. The best source for these is the model training
|
||||||
|
process, where for a classifier the inputs will be the nodes that receive the
|
||||||
|
data from the training set, and the output will be the predictions. If you're
|
||||||
|
unsure, the
|
||||||
|
[summarize_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/summarize_graph_main.cc) tool
|
||||||
|
can inspect the model and provide guesses about likely input and output nodes,
|
||||||
|
as well as other information that's useful for debugging. Here's an example of
|
||||||
|
how to use it on the [Inception V3
|
||||||
|
graph](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:summarize_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=tensorflow_inception_graph.pb
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Use Cases
|
||||||
|
|
||||||
|
This section has small guides for some of the most frequently-used
|
||||||
|
transformation pipelines, aimed at users who want to quickly accomplish one of
|
||||||
|
these tasks. A lot of them will use the Inception V3 model for their examples,
|
||||||
|
which can be downloaded from
|
||||||
|
[http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
|
||||||
|
|
||||||
|
### Optimizing for Deployment
|
||||||
|
|
||||||
|
If you've finished training your model and want to deploy it on a server or a
|
||||||
|
mobile device, you'll want it to run as fast as possible, and with as few
|
||||||
|
non-essential dependencies as you can. This recipe removes all of the nodes that
|
||||||
|
aren't called during inference, shrinks expressions that are always constant
|
||||||
|
into single nodes, and optimizes away some multiply operations used during batch
|
||||||
|
normalization by pre-multiplying the weights for convolutions.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
strip_unused_nodes(type=float, shape="1,299,299,3") \
|
||||||
|
remove_nodes(op=Identity, op=CheckNumerics) \
|
||||||
|
fold_constants(ignore_errors=true) \
|
||||||
|
fold_batch_norms \
|
||||||
|
fold_old_batch_norms\
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
The batch norm folding is included twice because there are two different flavors
|
||||||
|
of batch normalization used in TensorFlow. The older version was implemented
|
||||||
|
with a single BatchNormWithGlobalNormalization op, but it was deprecated in
|
||||||
|
favor of a more recent approach using individual ops to implement the same
|
||||||
|
computation. The two transforms are in there so that both styles are recognized
|
||||||
|
and optimized.
|
||||||
|
|
||||||
|
### Fixing Missing Kernel Errors on Mobile
|
||||||
|
|
||||||
|
The mobile version of TensorFlow is focused on inference, and so by default the
|
||||||
|
list of supported ops (defined in
|
||||||
|
[tensorflow/core/kernels/BUILD:android_extended_ops](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/BUILD#L2452)
|
||||||
|
for Bazel and
|
||||||
|
[tensorflow/contrib/makefile/tf_op_files.txt](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/makefile/tf_op_files.txt)
|
||||||
|
for make builds) doesn't include a lot that are training related. This can cause
|
||||||
|
`No OpKernel was registered to support Op` errors when a GraphDef is loaded,
|
||||||
|
even if the op isn't going to be executed.
|
||||||
|
|
||||||
|
If you see this error and it's an op that you do actually want to run on mobile,
|
||||||
|
then you'll need to make local modifications to the build files to include the
|
||||||
|
right .cc file that defines it. In a lot of cases the op is just a vestigial
|
||||||
|
remnant from the training process though, and if that's true then you can run
|
||||||
|
the [strip_unused_nodes](#strip_unused_nodes) transform, specifying the inputs and outputs
|
||||||
|
of your inference usage, to remove those unnecessary nodes:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
strip_unused_nodes(type=float, shape="1,299,299,3") \
|
||||||
|
fold_constants \
|
||||||
|
fold_batch_norms \
|
||||||
|
fold_old_batch_norms\
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Shrinking File Size
|
||||||
|
|
||||||
|
If you're looking to deploy your model as part of a mobile app, then keeping the
|
||||||
|
download size as small as possible is important. For most TensorFlow models, the
|
||||||
|
largest contributors to the file size are the weights passed in to convolutional
|
||||||
|
and fully-connected layers, so anything that can reduce the storage size for
|
||||||
|
those is very useful. Luckily most neural networks are resistant to noise, so
|
||||||
|
it's possible to change the representation of those weights in a lossy way
|
||||||
|
without losing very much accuracy overall.
|
||||||
|
|
||||||
|
On both iOS and Android, app packages are compressed before download, so the
|
||||||
|
simplest way to reduce the bandwidth your users need to receive your app is to
|
||||||
|
provide raw data that compresses more easily. By default the weights are stored
|
||||||
|
as floating-point values, and even tiny differences between numbers result in
|
||||||
|
very different bit patterns, and so these don't compress very well. If you round
|
||||||
|
the weights so that nearby numbers are stored as exactly the same values, the
|
||||||
|
resulting bit stream has a lot more repetition and so compresses down a lot more
|
||||||
|
effectively. To try this technique on your model, run the
|
||||||
|
[round_weights](#round_weights) transform.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
round_weights(num_steps=256) \
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see that the `optimized_inception_graph.pb` output file is the same
|
||||||
|
size as the input, but if you run zip on it to compress it, it's almost 70%
|
||||||
|
smaller than if you zip the original! The nice thing about this transform is
|
||||||
|
that it doesn't change the structure of the graph at all, so it's running
|
||||||
|
exactly the same operations and should have the same latency and memory usage as
|
||||||
|
before. You can adjust the `num_steps` parameter to control how many values each
|
||||||
|
weight buffer is rounded to, so lower numbers will increase the compression at
|
||||||
|
the cost of accuracy.
|
||||||
|
|
||||||
|
As a further step, you can store the weights into eight-bit values directly.
|
||||||
|
Here's the recipe for that:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
quantize_weights \
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see that the size of the output graph is about a quarter of the
|
||||||
|
original. The downside to this approach compared to round_weights is that extra
|
||||||
|
decompression ops are inserted to convert the eight-bit values back into
|
||||||
|
floating point, but optimizations in TensorFlow's runtime should ensure these
|
||||||
|
results are cached and so you shouldn't see the graph run any more slowly.
|
||||||
|
|
||||||
|
So far we've been concentrating on weights because those generally take up the
|
||||||
|
most space. If you have a graph with a lot of small nodes in it, the names of
|
||||||
|
those nodes can start to take up a noticeable amount of space too. To shrink
|
||||||
|
those down, you can run the [obsfucate_names](#obsfucate_names) transform, which
|
||||||
|
replaces all the names (except for inputs and outputs) with short, cryptic but
|
||||||
|
unique ids:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
obsfucate_names \
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Eight-bit Calculations
|
||||||
|
|
||||||
|
For some platforms it's very helpful to be able to do as many calculations as
|
||||||
|
possible in eight-bit, rather than floating-point. The support for this in
|
||||||
|
TensorFlow is still experimental and evolving, but you can convert models into
|
||||||
|
quantized form using the graph transform tool:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/tools/graph_transforms:transform_graph
|
||||||
|
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||||
|
--in_graph=tensorflow_inception_graph.pb \
|
||||||
|
--out_graph=optimized_inception_graph.pb \
|
||||||
|
--inputs='Mul:0' \
|
||||||
|
--outputs='softmax:0' \
|
||||||
|
--transforms='\
|
||||||
|
strip_unused_nodes(type=float, shape="1,299,299,3") \
|
||||||
|
remove_nodes(op=Identity, op=CheckNumerics) \
|
||||||
|
fold_old_batch_norms \
|
||||||
|
quantize_weights \
|
||||||
|
quantize_nodes \
|
||||||
|
strip_unused_nodes \
|
||||||
|
'
|
||||||
|
```
|
||||||
|
|
||||||
|
This process converts all the operations in the graph that have quantized
|
||||||
|
equivalents, and leaves the rest in floating point. Only a subset of ops are
|
||||||
|
supported, and on many platforms the quantized code may actually be slower than
|
||||||
|
the float equivalents, but this is a way of increasing performance substantially
|
||||||
|
when all the circumstances are right.
|
||||||
|
|
||||||
|
A full guide to optimizing for quantization is beyond the scope of this guide,
|
||||||
|
but one thing that can help is using the FakeQuantWithMinMaxVars op after Conv2D
|
||||||
|
or similar operations during training. This trains the min/max variables that
|
||||||
|
control the range used for quantization, so that the range doesn't have to be
|
||||||
|
calculated dynamically by RequantizationRange during inference.
|
||||||
|
|
||||||
|
## Transform Reference
|
||||||
|
|
||||||
|
The transforms string is parsed as a series of transform names, each of which
|
||||||
|
can have multiple named arguments inside parentheses. Arguments are separated by
|
||||||
|
commas, and double-quotes (") can be used to hold argument values if they
|
||||||
|
themselves contain commas (for example shape definitions).
|
||||||
|
|
||||||
|
The --inputs and --outputs are shared across all transforms, since it's common
|
||||||
|
to need to know what the ingoing and outgoing nodes in the graph are. You should
|
||||||
|
make sure you set these correctly before calling the graph transform tool, and
|
||||||
|
if you're in doubt check with the model's author, or use the `check_graph` tool
|
||||||
|
to examine likely inputs and outputs.
|
||||||
|
|
||||||
|
All transforms can be passed the `ignore_errors` flag, with the value set to
|
||||||
|
either true or false. By default any errors that happen within a transform will
|
||||||
|
abort the whole process, but if you enable this then an error will just be
|
||||||
|
logged and the transform skipped. This is especially useful for optional
|
||||||
|
transforms where version errors or other unimportant problems may trigger an
|
||||||
|
error.
|
||||||
|
|
||||||
|
### fold_batch_norms
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
This transform tries to optimize away the Mul that's introduced after a Conv2D
|
||||||
|
when batch normalization has been used during training. It scans the graph for
|
||||||
|
any channel-wise multiplies immediately after convolutions, and multiplies the
|
||||||
|
convolution's weights with the mul instead so it can be omitted. You'll need to
|
||||||
|
make sure you run [fold_constants](#fold_constants) first, since the pattern can
|
||||||
|
only be spotted if the normal complex expression that's produced by training for
|
||||||
|
the Mul input is collapsed down into a simple constant.
|
||||||
|
|
||||||
|
### fold_constants
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
Looks for any sub-graphs within the model that always evaluate to constant
|
||||||
|
expressions, and replaces them with those constants. This optimization is always
|
||||||
|
executed at run-time after the graph is loaded, so running it offline first
|
||||||
|
won't help latency, but it can simplify the graph and so make further processing
|
||||||
|
easier. It's often useful to call this with `fold_constants(ignore_errors=true)`
|
||||||
|
to continue on past transient errors, since this is just an optimization phase.
|
||||||
|
|
||||||
|
### fold_old_batch_norms
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
In the early days of TensorFlow, batch normalization was implemented using a
|
||||||
|
single monolithic `BatchNormWithGlobalNormalization` op. In modern versions,
|
||||||
|
adding batch normalization from Python will give you a series of smaller math
|
||||||
|
ops instead, to achieve the same effect without special-purpose code. If you
|
||||||
|
have a graph that uses the older style, this transform will recognize and
|
||||||
|
optimize those ops for inference, in the same way that the
|
||||||
|
[fold_batch_norms](#fold_batch_norms) transform does for the new approach.
|
||||||
|
|
||||||
|
### fuse_convolutions
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
For graphs that use ResizeBilinear or MirrorPad ops before convolutions,
|
||||||
|
typically to scale up in the later stages of an image style transfer model for
|
||||||
|
example, it can save on memory usage and latency to combine the spatial
|
||||||
|
transformations with the convolution's im2col patch generation. This transform
|
||||||
|
looks out for that particular pattern of ops and replaces them with a fused
|
||||||
|
version that combines the resizing and padding with the convolution.
|
||||||
|
|
||||||
|
### merge_duplicate_nodes
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
If there are Const nodes with the same types and contents, or nodes with the
|
||||||
|
same inputs and attributes, this transform will merge them together. It can be
|
||||||
|
useful when you want to cut down the number of nodes in a graph that has a lot
|
||||||
|
of redundancy, and is always run as part of [quantize_nodes](#quantize_nodes)
|
||||||
|
since the processing there can introduce duplicates of constants that are used
|
||||||
|
in the quantize/dequantize process.
|
||||||
|
|
||||||
|
### obsfucate_names
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
Replaces all nodes' names with short generated ids, other than the inputs and
|
||||||
|
outputs. This also updates all references within the graph so that the structure
|
||||||
|
is preserved. This can be useful if you want to shrink the file size, or if you
|
||||||
|
want to make it harder to understand the architecture of your model before
|
||||||
|
releasing it.
|
||||||
|
|
||||||
|
### quantize_nodes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* input_min: The lowest float value for any quantized placeholder inputs.
|
||||||
|
* input_max: The highest float value for any quantized placeholder inputs. If
|
||||||
|
both input_min and input_max are set, then any float placeholders in the
|
||||||
|
graph will be replaced with quantized versions, and consts will be created
|
||||||
|
to pass the range to subsequent operations.
|
||||||
|
* fallback_min: The lowest float value to use for requantizing activation
|
||||||
|
layers.
|
||||||
|
* fallback_max: The highest float value to use for requantizing activation
|
||||||
|
layers. If both fallback_min and fallback_max are set, then instead of using
|
||||||
|
RequantizationRange ops to figure out the useful range dynamically when
|
||||||
|
converting the 32-bit output of ops like QuantizedConv2D and
|
||||||
|
QuantizedBiasAdd, hardwired consts with these values will be used instead.
|
||||||
|
This can help performance, if you know the range of your activation layers
|
||||||
|
ahead of time.
|
||||||
|
|
||||||
|
Replaces any calculation nodes with their eight-bit equivalents, and adds in
|
||||||
|
conversion layers to allow remaining float operations to interoperate. This is
|
||||||
|
one of the most complex transforms, and involves multiple passes and a lot of
|
||||||
|
rewriting. It's also still an active area of research, so results may vary
|
||||||
|
depending on the platform and operations you're using in your model. You should
|
||||||
|
run [quantize_weights](#quantize_weights) first to ensure your Const ops are in
|
||||||
|
eight-bit form.
|
||||||
|
|
||||||
|
### quantize_weights
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
Converts any large (more than 15 elements) float Const op into an eight-bit
|
||||||
|
equivalent, followed by a float conversion op so that the result is usable by
|
||||||
|
subsequent nodes. This is mostly useful for [shrinking file
|
||||||
|
sizes](#shrinking-file-size), but also helps with the more advanced
|
||||||
|
[quantize_nodes](#quantize_nodes) transform.
|
||||||
|
|
||||||
|
### remove_attribute
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* attribute_name: Name of the attribute you want to remove.
|
||||||
|
* op_name: Optional name of a single op to restrict the removal to.
|
||||||
|
|
||||||
|
Deletes the given attribute from either all nodes, or just the one specified in
|
||||||
|
`op_name`. This can be a dangerous transform since it's easy to leave your graph
|
||||||
|
in an invalid state if you remove a required attribute. It can be useful in
|
||||||
|
special circumstances though.
|
||||||
|
|
||||||
|
### remove_device
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
All ops can have a hardware device specified. This can be a problem when you're
|
||||||
|
loading a graph on a different system than the model was trained on, since some
|
||||||
|
specified devices may not be available. In order to work with graphs like these,
|
||||||
|
you can run this transform to wipe the slate clean and delete the device
|
||||||
|
specifier from all ops.
|
||||||
|
|
||||||
|
### remove_nodes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* op: The name of the op you want to remove. Can be repeated to remove
|
||||||
|
multiple ops.
|
||||||
|
|
||||||
|
This is a potentially dangerous transform that looks for single-input,
|
||||||
|
single-output ops with the given names, removes them from the graph, and rewires
|
||||||
|
all inputs that used to pull from them to pull from the preceding node instead.
|
||||||
|
This is most useful for getting rid of ops like `CheckNumerics` that are useful
|
||||||
|
during training but just complicate the graph and increase latency during
|
||||||
|
inference. It's dangerous because it's possible that removing some ops may
|
||||||
|
change the output of your graph, so make sure you check the overall accuracy
|
||||||
|
after using this.
|
||||||
|
|
||||||
|
### rename_attribute
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* old_attribute_name: Current name of the attribute you want to rename.
|
||||||
|
* new_attribute_name: Name that you want the attribute to have now.
|
||||||
|
* op_name: If this is set, only change attributes for a given op type,
|
||||||
|
otherwise apply to all nodes with attribute names that match.
|
||||||
|
|
||||||
|
Changes the name of the given attribute. This is often useful for upgrading
|
||||||
|
graph files as op definitions change over versions, since the renaming is often
|
||||||
|
enough to deal with minor changes.
|
||||||
|
|
||||||
|
### rename_op
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* old_op_name: Current name of the operation.
|
||||||
|
* new_op_name: Name to change to.
|
||||||
|
|
||||||
|
Finds all ops with the given name, and changes them to the new one. This can be
|
||||||
|
useful for version upgrading if the changes between ops are minor apart from the
|
||||||
|
name.
|
||||||
|
|
||||||
|
### round_weights
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* num_steps: How many unique values to use in each buffer.
|
||||||
|
|
||||||
|
Rounds all float values in large Const ops (more than 15 elements) to the given
|
||||||
|
number of steps. The unique values are chosen per buffer by linearly allocating
|
||||||
|
between the largest and smallest values present. This is useful when you'll be
|
||||||
|
deploying on mobile, and you want a model that will compress effectively. See
|
||||||
|
[shrinking file size](#shrinking-file-size) for more details.
|
||||||
|
|
||||||
|
### sort_by_execution_order
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
Arranges the nodes in the GraphDef in topological order, so that the inputs of
|
||||||
|
any given node are always earlier than the node itself. This is especially
|
||||||
|
useful when you're targeting a minimal inference engine, since you can just
|
||||||
|
execute the nodes in the given order knowing that the inputs will be computed
|
||||||
|
before they're needed.
|
||||||
|
|
||||||
|
### strip_unused_nodes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* type: Default type for any new Placeholder nodes generated, for example
|
||||||
|
int32, float, quint8.
|
||||||
|
* shape: Default shape for any new Placeholder nodes generated, as
|
||||||
|
comma-separated dimensions. For example shape="1,299,299,3". The double
|
||||||
|
quotes are important, since otherwise the commas will be taken as argument
|
||||||
|
separators.
|
||||||
|
* name: Identifier for the placeholder arguments.
|
||||||
|
* type_for_name: What type to use for the previously-given name.
|
||||||
|
* shape_for_name: What shape to use for the previously-given name.
|
||||||
|
|
||||||
|
Removes all nodes not used in calculating the layers given in `--outputs`, fed by
|
||||||
|
`--inputs`. This is often useful for removing training-only nodes like
|
||||||
|
save-and-restore or summary ops. It's also handy for solving the [missing kernel
|
||||||
|
errors problem](#fixing-missing-kernel-errors-on-mobile) when there are decode
|
||||||
|
or other ops you don't need in the inference path.
|
||||||
|
|
||||||
|
The biggest complication is that it sometimes has to create new Placeholder ops,
|
||||||
|
so there are options to control their characteristics. This will happen if you
|
||||||
|
bypass a DecodeJpeg op by specifying an input layer deeper in the network, for
|
||||||
|
example, so you can pass in a raw image array instead of an encoded string as an
|
||||||
|
input. The decode op will be removed, together with the Placeholder that fed it,
|
||||||
|
but a new Placeholder is needed for the input layer you specify. The type and
|
||||||
|
shape arguments let you control the attributes of any new Placeholders that are
|
||||||
|
created. Plain `type` and `shape` set global defaults, but if you have different
|
||||||
|
inputs with varying characteristics, you'll need to pass in a list of arguments
|
||||||
|
where the preceding name specifies what layer each applies to. For example, if
|
||||||
|
you had two inputs in1 and in2, you could call `strip_unused_nodes(name=in1,
|
||||||
|
type_for_name=int32, shape_for_name="2,3", name=in2, type_for_name=float,
|
||||||
|
shape_for_name="1,10,10,3")`.
|
||||||
|
|
||||||
|
## Writing Your Own Transforms
|
||||||
|
|
||||||
|
The Graph Transform Tool is designed to make it as easy as possible to create
|
||||||
|
your own optimization, modification, and pre-processing transforms. At their
|
||||||
|
heart, all of the transforms take in a valid GraphDef, make some changes, and
|
||||||
|
output a new GraphDef. Each GraphDef is just a list of NodeDefs, each defining
|
||||||
|
one node in the graph and its connections. You can find more information on the
|
||||||
|
format at [this guide to TensorFlow model
|
||||||
|
files](https://www.tensorflow.org/versions/master/how_tos/tool_developers/index.html),
|
||||||
|
but for a simple example take a look at
|
||||||
|
[tensorflow/tools/graph_transforms/rename_op.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/rename_op.cc),
|
||||||
|
which implements the [rename_op](#rename_op) transform:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
Status RenameOp(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
if (!context.params.count("old_op_name") ||
|
||||||
|
(context.params.at("old_op_name").size() != 1) ||
|
||||||
|
!context.params.count("new_op_name") ||
|
||||||
|
(context.params.at("new_op_name").size() != 1)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"rename_op expects exactly one 'old_op_name' and 'new_op_name' "
|
||||||
|
"argument, e.g. rename_op(old_op_name=Mul, new_op_name=Multiply)");
|
||||||
|
}
|
||||||
|
|
||||||
|
const string old_op_name = context.params.at("old_op_name")[0];
|
||||||
|
const string new_op_name = context.params.at("new_op_name")[0];
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& node : input_graph_def.node()) {
|
||||||
|
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||||
|
new_node->CopyFrom(node);
|
||||||
|
if (node.op() == old_op_name) {
|
||||||
|
new_node->set_op(new_op_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("rename_op", RenameOp);
|
||||||
|
```
|
||||||
|
|
||||||
|
The heart of this transform is the loop through the input_graph_def's nodes. We
|
||||||
|
go through each op, add a new one to the output, copy the original's contents,
|
||||||
|
and then change the op over if it matches the parameters. There's a standard set
|
||||||
|
of parameters for every transform, so they all take in a GraphDef and context,
|
||||||
|
and write out into a new GraphDef. The registration macro at the bottom lets the
|
||||||
|
tool know what function to call when it finds the `rename_op` string in a
|
||||||
|
transforms list.
|
||||||
|
|
||||||
|
### Transform Functions
|
||||||
|
|
||||||
|
The standard signature that all transform functions have is defined as
|
||||||
|
`TransformFunc`, which takes in an input GraphDef, a `TransformFuncContext`
|
||||||
|
containing environment information, writes to an output GraphDef, and returns a
|
||||||
|
Status indicating whether the transform succeeded.
|
||||||
|
|
||||||
|
The `TransformFuncContext` has a list of the inputs and outputs for the graph,
|
||||||
|
and the [parameter arguments](#parameters) that were passed into the transform
|
||||||
|
by the user.
|
||||||
|
|
||||||
|
If you write a function that matches this signature, and [register
|
||||||
|
it](#registering), the graph transform tool will take care of calling it.
|
||||||
|
|
||||||
|
### Pattern Syntax
|
||||||
|
|
||||||
|
The `rename_op` example only needs to look at a single node at a time, but one
|
||||||
|
of the most common needs is to modify small sub-graphs within a model. To make
|
||||||
|
this easy, the Graph Transform Tool provides the `OpTypePattern` syntax. This is
|
||||||
|
a simple and compact way to specify patterns of nodes that you want to look for.
|
||||||
|
For example, if you want all Conv2D nodes that have a constant as their second
|
||||||
|
input, you would set up a pattern like this, using C++ initializer lists to
|
||||||
|
populate the structure:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
OpTypePattern conv_pattern({"Conv2D", {{"*"}, {"Const"}}});
|
||||||
|
```
|
||||||
|
|
||||||
|
It can be easier to visualize these initializers using indentation to show the
|
||||||
|
tree structure more clearly:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
OpTypePattern conv_pattern({
|
||||||
|
"Conv2D",
|
||||||
|
{
|
||||||
|
{"*"},
|
||||||
|
{"Const"}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
In plain English this is saying, a Conv2D op with two inputs, the first of which
|
||||||
|
is any op type, and the second is a Const op.
|
||||||
|
|
||||||
|
The op field can either contain a single "*", which means match any op type, one
|
||||||
|
op type (for example "Const"), or a set of op types separated by `|` symbols
|
||||||
|
(for example "Conv2D|MatMul|BiasAdd"). General regex patterns are not supported,
|
||||||
|
just these special cases.
|
||||||
|
|
||||||
|
You can think of these patterns as very limited regular expressions designed to
|
||||||
|
pick out sub-trees in graphs. They are deliberately very constrained to the kind
|
||||||
|
of things we commonly find ourselves needing to do, to make creating and
|
||||||
|
debugging as straightforward as possible.
|
||||||
|
|
||||||
|
Here's a much more complex example, from the [quantize_nodes](#quantize_nodes)
|
||||||
|
transform:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
{"QuantizeV2",
|
||||||
|
{
|
||||||
|
{"Dequantize"},
|
||||||
|
{"Min",
|
||||||
|
{
|
||||||
|
{"Reshape",
|
||||||
|
{
|
||||||
|
{"Dequantize"},
|
||||||
|
{"Const"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Const"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Max",
|
||||||
|
{
|
||||||
|
{"Reshape",
|
||||||
|
{
|
||||||
|
{"Dequantize"},
|
||||||
|
{"Const"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Const"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This is looking for QuantizeV2 nodes, with three inputs, the first of which is a
|
||||||
|
Dequantize, the second is a Min that ultimately pulls from a Dequantize, and the
|
||||||
|
third is a Max which does the same. We know the end result of this sub-graph is
|
||||||
|
a no-op, since it's just turning an eight-bit buffer into float, and then
|
||||||
|
immediately converting it back to eight-bits, so if we look for this pattern and
|
||||||
|
remove it we can optimize the graph without changing the result.
|
||||||
|
|
||||||
|
### ReplaceMatchingOpTypes
|
||||||
|
|
||||||
|
It's very common to want to find all occurrences of a particular sub-graph in a
|
||||||
|
model, and replace them all with a different sub-graph that keeps the same local
|
||||||
|
input and output connections. For example with
|
||||||
|
[fuse_convolutions](#fuse_convolutions), we needed to find all Conv2D ops that
|
||||||
|
read their inputs from ResizeBilinear ops, and replace those combinations with a
|
||||||
|
single FusedResizeAndPadConv2D op, but without affecting other ops.
|
||||||
|
|
||||||
|
To make that sort of transformation easy, we created the
|
||||||
|
`ReplaceMatchingOpTypes` helper. This takes in a graph, an `OpTypePattern`
|
||||||
|
defining the sub-graph to look for, and a callback function to run for every
|
||||||
|
occurrence it finds. The job of this callback function is to look at the
|
||||||
|
`NodeMatch` that contains information about the current sub-graph, and return a
|
||||||
|
new sub-graph in the new_nodes list that will be used to replace the old
|
||||||
|
sub-graph.
|
||||||
|
|
||||||
|
You can see how it's used in practice in the
|
||||||
|
[fuse_convolutions](#fuse_convolutions) code:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, // clang-format off
|
||||||
|
{"Conv2D",
|
||||||
|
{
|
||||||
|
{"ResizeBilinear"},
|
||||||
|
{"*"}
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
// Find all the nodes we expect in the subgraph.
|
||||||
|
const NodeDef& conv_node = match.node;
|
||||||
|
const NodeDef& resize_node = match.inputs[0].node;
|
||||||
|
const NodeDef& weights_node = match.inputs[1].node;
|
||||||
|
|
||||||
|
// We'll be reusing the old weights.
|
||||||
|
new_nodes->push_back(weights_node);
|
||||||
|
|
||||||
|
// Create a 'no-op' mirror padding node that has no effect.
|
||||||
|
NodeDef pad_dims_node;
|
||||||
|
pad_dims_node.set_op("Const");
|
||||||
|
pad_dims_node.set_name(conv_node.name() + "_dummy_paddings");
|
||||||
|
SetNodeAttr("dtype", DT_INT32, &pad_dims_node);
|
||||||
|
SetNodeTensorAttr<int32>("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
&pad_dims_node);
|
||||||
|
new_nodes->push_back(pad_dims_node);
|
||||||
|
|
||||||
|
// Set up the new fused version of the convolution op.
|
||||||
|
NodeDef fused_conv;
|
||||||
|
fused_conv.set_op("FusedResizeAndPadConv2D");
|
||||||
|
fused_conv.set_name(match.node.name());
|
||||||
|
AddNodeInput(resize_node.input(0), &fused_conv);
|
||||||
|
AddNodeInput(resize_node.input(1), &fused_conv);
|
||||||
|
AddNodeInput(pad_dims_node.name(), &fused_conv);
|
||||||
|
AddNodeInput(conv_node.input(1), &fused_conv);
|
||||||
|
CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
|
||||||
|
&fused_conv);
|
||||||
|
SetNodeAttr("mode", "REFLECT", &fused_conv);
|
||||||
|
CopyNodeAttr(conv_node, "T", "T", &fused_conv);
|
||||||
|
CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
|
||||||
|
CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
|
||||||
|
new_nodes->push_back(fused_conv);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &replaced_graph_def));
|
||||||
|
```
|
||||||
|
|
||||||
|
Here you can see we define the pattern to look for, and in the callback function
|
||||||
|
use information from each of the nodes in the old sub-graph to create a new
|
||||||
|
fused node. We also copy over the old weights input node so that isn't lost.
|
||||||
|
|
||||||
|
There are a few things to know about the `ReplaceMatchingOpTypes` function:
|
||||||
|
|
||||||
|
* All of the nodes in any matching sub-graphs are removed from the new graph
|
||||||
|
created by the function. If any of them are needed, it's the callback
|
||||||
|
function's responsibility to add them back in. There's a `CopyOriginalMatch`
|
||||||
|
convenience call that will copy over all of the original nodes if you decide
|
||||||
|
you don't actually want to modify a particular sub-graph.
|
||||||
|
|
||||||
|
* Nodes will never appear in more than one matched sub-graph. This is to
|
||||||
|
ensure that sub-trees are only replaced once, but it may mean that some
|
||||||
|
sub-graphs aren't spotted if they overlap with earlier matches.
|
||||||
|
|
||||||
|
* The calling framework tries to ensure that the graph remains sane, by
|
||||||
|
looking at the new_nodes that are returned and making sure that no nodes
|
||||||
|
which are needed as inputs by nodes outside the sub-graph are removed. These
|
||||||
|
important nodes are listed in the `output_nodes` argument that's passed into
|
||||||
|
each replacement function call. You can disable this checking by setting
|
||||||
|
`allow_inconsistencies` to true in the options, but otherwise any
|
||||||
|
replacements that break the graph constraints will be cancelled. If you do
|
||||||
|
allow inconsistencies, it's your transform's responsibility to fix them up
|
||||||
|
before you return your final result. Functions like `RenameNodeInputs` can
|
||||||
|
be useful if you are doing wholesale node renaming for example.
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
|
||||||
|
The arguments that are in parentheses after the transform name when the tool is
|
||||||
|
called are parsed and placed into the params member of the TransformFuncContext
|
||||||
|
that's given to each transform. For every named argument, there's a vector of
|
||||||
|
strings containing all the values that it was given, in the order they were
|
||||||
|
given. These are treated a bit like command-line parameters, and it's the
|
||||||
|
transform's responsibility to parse them into the data types it needs, and raise
|
||||||
|
errors by returning a bad Status if any of them are ill-formed.
|
||||||
|
|
||||||
|
As an example, here's a hypothetical transform call:
|
||||||
|
|
||||||
|
```
|
||||||
|
some_transform(foo=a, foo=b, bar=2, bob="1,2,3")
|
||||||
|
```
|
||||||
|
|
||||||
|
Here's what the std::map of strings looks like in the params member:
|
||||||
|
|
||||||
|
```
|
||||||
|
{{"foo", {"a", "b"}}, {"bar", {"2"}}, {"bob", {"1,2,3"}}}
|
||||||
|
```
|
||||||
|
|
||||||
|
The double quotes around the comma-separated argument to `bob` are important
|
||||||
|
because otherwise they'll be treated as separate arguments, and the parsing will
|
||||||
|
fail.
|
||||||
|
|
||||||
|
Here's an example of how [round_weights](#round_weights) reads its `num_steps`
|
||||||
|
parameter:
|
||||||
|
|
||||||
|
```C++
|
||||||
|
string num_steps_string;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
|
||||||
|
int32 num_steps;
|
||||||
|
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Couldn't interpret the num_steps argument to round_weights as a "
|
||||||
|
"number:",
|
||||||
|
num_steps_string);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Things to notice here are that you have to convert the string to an integer, and
|
||||||
|
if the conversion fails you need to raise a meaningful error through the status
|
||||||
|
result of the transform. We're also using a helper function which raises an
|
||||||
|
error if the parameter is present multiple times, and uses a default if the user
|
||||||
|
hasn't specified it.
|
||||||
|
|
||||||
|
### Function Libraries
|
||||||
|
|
||||||
|
A newer feature of TensorFlow is the ability to create libraries of functions as
|
||||||
|
part of graphs. These are a bit like templates, which define macro operations in
|
||||||
|
terms of smaller components, which can then be instantiated with different input
|
||||||
|
and output connections inside the graph just like regular ops. Right now the
|
||||||
|
graph transform tool just copies these libraries between the input and output
|
||||||
|
graphs, but it's likely that more complex operations will be supported on them
|
||||||
|
in the future.
|
||||||
|
|
||||||
|
### Registering
|
||||||
|
|
||||||
|
The Graph Transform Tool associates names of transforms with the code to
|
||||||
|
implement them using the `REGISTER_GRAPH_TRANSFORM()` macro. This takes a string
|
||||||
|
and a function, and automagically registers the transform with the tool. You
|
||||||
|
will need to watch out for a few things though:
|
||||||
|
|
||||||
|
* Because it's using global C++ objects in each file under the hood, the
|
||||||
|
linker can sometimes strip them out and lose the registration. In Bazel you
|
||||||
|
need to make sure you're linking any new transforms in as libraries, and use
|
||||||
|
the `alwayslink` flag in your `cc_binary` call.
|
||||||
|
|
||||||
|
* You should be able to create your own copy of the transform_graph tool by
|
||||||
|
linking against the transform_graph_main_lib library in
|
||||||
|
tensorflow/tools/graph_transforms/BUILD. This contains all the `main()`
|
||||||
|
logic to parse command line arguments and call transforms.
|
108
tensorflow/tools/graph_transforms/fold_batch_norms.cc
Normal file
108
tensorflow/tools/graph_transforms/fold_batch_norms.cc
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Converts Conv2D ops followed by column-wise Muls into equivalent ops with the
|
||||||
|
// Mul baked into the convolution weights, to save computation during inference.
|
||||||
|
Status FoldBatchNorms(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
GraphDef replaced_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, // clang-format off
|
||||||
|
{"Mul", // mul_node
|
||||||
|
{
|
||||||
|
{"Conv2D", // conv_node
|
||||||
|
{
|
||||||
|
{"*"}, // input_node
|
||||||
|
{"Const"}, // weights_node
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Const"}, // mul_values_node
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
// Find all the nodes we expect in the subgraph.
|
||||||
|
const NodeDef& mul_node = match.node;
|
||||||
|
const NodeDef& conv_node = match.inputs[0].node;
|
||||||
|
const NodeDef& input_node = match.inputs[0].inputs[0].node;
|
||||||
|
const NodeDef& weights_node = match.inputs[0].inputs[1].node;
|
||||||
|
const NodeDef& mul_values_node = match.inputs[1].node;
|
||||||
|
|
||||||
|
Tensor weights = GetNodeTensorAttr(weights_node, "value");
|
||||||
|
Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value");
|
||||||
|
|
||||||
|
// Make sure all the inputs really are vectors, with as many entries as
|
||||||
|
// there are columns in the weights.
|
||||||
|
const int64 weights_cols = weights.shape().dim_size(3);
|
||||||
|
if ((mul_values.shape().dims() != 1) ||
|
||||||
|
(mul_values.shape().dim_size(0) != weights_cols)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Mul constant input to batch norm has bad shape: ",
|
||||||
|
mul_values.shape().DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply the original weights by the scale vector.
|
||||||
|
auto weights_matrix = weights.flat_inner_dims<float>();
|
||||||
|
Tensor scaled_weights(DT_FLOAT, weights.shape());
|
||||||
|
auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
|
||||||
|
for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
|
||||||
|
for (int64 col = 0; col < weights_cols; ++col) {
|
||||||
|
scaled_weights_matrix(row, col) =
|
||||||
|
weights_matrix(row, col) * mul_values.flat<float>()(col);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the new nodes.
|
||||||
|
NodeDef scaled_weights_node;
|
||||||
|
scaled_weights_node.set_op("Const");
|
||||||
|
scaled_weights_node.set_name(weights_node.name());
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
|
||||||
|
SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
|
||||||
|
new_nodes->push_back(scaled_weights_node);
|
||||||
|
|
||||||
|
new_nodes->push_back(input_node);
|
||||||
|
|
||||||
|
NodeDef new_conv_node;
|
||||||
|
new_conv_node.CopyFrom(conv_node);
|
||||||
|
new_conv_node.set_name(mul_node.name());
|
||||||
|
new_nodes->push_back(new_conv_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &replaced_graph_def));
|
||||||
|
*output_graph_def = replaced_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("fold_batch_norms", FoldBatchNorms);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
93
tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
Normal file
93
tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status FoldBatchNorms(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
// Fixture that checks FoldBatchNorms preserves a graph's output while removing
// the Mul op that follows a Conv2D.
class FoldBatchNormsTest : public ::testing::Test {
 protected:
  // Builds Const -> Conv2D -> Mul, runs the graph before and after the
  // transform, and expects near-identical outputs with no Mul node remaining.
  void TestFoldBatchNorms() {
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
    auto scope = tensorflow::Scope::NewRootScope();

    // Convolution input.
    Tensor input_tensor(DT_FLOAT, TensorShape({1, 1, 6, 2}));
    test::FillValues<float>(
        &input_tensor, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
                        -5.0f, -3.0f, -6.0f});
    Output input_const =
        Const(scope.WithOpName("input_op"), Input::Initializer(input_tensor));

    // Convolution weights.
    Tensor weights_tensor(DT_FLOAT, TensorShape({1, 2, 2, 2}));
    test::FillValues<float>(&weights_tensor,
                            {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
    Output weights_const = Const(scope.WithOpName("weights_op"),
                                 Input::Initializer(weights_tensor));

    Output conv = Conv2D(scope.WithOpName("conv_op"), input_const,
                         weights_const, {1, 1, 1, 1}, "VALID");

    // Per-column scale applied after the convolution.
    Tensor scale_tensor(DT_FLOAT, TensorShape({2}));
    test::FillValues<float>(&scale_tensor, {2.0f, 3.0f});
    Output scale_const = Const(scope.WithOpName("mul_values"),
                               Input::Initializer(scale_tensor));

    Output scaled = Mul(scope.WithOpName("output"), conv, scale_const);

    GraphDef unfused_graph_def;
    TF_ASSERT_OK(scope.ToGraphDef(&unfused_graph_def));

    // Reference result from the untransformed graph.
    std::unique_ptr<Session> unfused_session(NewSession(SessionOptions()));
    TF_ASSERT_OK(unfused_session->Create(unfused_graph_def));
    std::vector<Tensor> expected;
    TF_ASSERT_OK(unfused_session->Run({}, {"output"}, {}, &expected));

    // Apply the transform, with no inputs and "output" as the only output.
    TransformFuncContext transform_context;
    transform_context.output_names = {"output"};
    GraphDef folded_graph_def;
    TF_ASSERT_OK(
        FoldBatchNorms(unfused_graph_def, transform_context, &folded_graph_def));

    std::unique_ptr<Session> folded_session(NewSession(SessionOptions()));
    TF_ASSERT_OK(folded_session->Create(folded_graph_def));
    std::vector<Tensor> actual;
    TF_ASSERT_OK(folded_session->Run({}, {"output"}, {}, &actual));

    test::ExpectTensorNear<float>(expected[0], actual[0], 1e-5);

    // The Mul should have been folded away.
    for (const NodeDef& folded_node : folded_graph_def.node()) {
      EXPECT_NE("Mul", folded_node.op());
    }
  }
};
|
||||||
|
|
||||||
|
TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
@ -95,17 +95,16 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
|
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
|
||||||
const std::vector<string>& inputs,
|
const TransformFuncContext& context,
|
||||||
const std::vector<string>& outputs,
|
|
||||||
GraphDef* output_graph_def) {
|
GraphDef* output_graph_def) {
|
||||||
std::map<string, const NodeDef*> node_map;
|
std::map<string, const NodeDef*> node_map;
|
||||||
MapNamesToNodes(input_graph_def, &node_map);
|
MapNamesToNodes(input_graph_def, &node_map);
|
||||||
|
|
||||||
std::map<string, bool> used_nodes;
|
std::map<string, bool> used_nodes;
|
||||||
for (const string& input : inputs) {
|
for (const string& input : context.input_names) {
|
||||||
used_nodes[input] = true;
|
used_nodes[input] = true;
|
||||||
}
|
}
|
||||||
std::vector<string> current_nodes = outputs;
|
std::vector<string> current_nodes = context.output_names;
|
||||||
while (!current_nodes.empty()) {
|
while (!current_nodes.empty()) {
|
||||||
std::vector<string> next_nodes;
|
std::vector<string> next_nodes;
|
||||||
for (const string& node_name : current_nodes) {
|
for (const string& node_name : current_nodes) {
|
||||||
@ -134,9 +133,10 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts any sub-graphs that can be resolved into constant expressions into
|
||||||
|
// single Const ops.
|
||||||
Status FoldConstants(const GraphDef& input_graph_def,
|
Status FoldConstants(const GraphDef& input_graph_def,
|
||||||
const std::vector<string>& inputs,
|
const TransformFuncContext& context,
|
||||||
const std::vector<string>& outputs,
|
|
||||||
GraphDef* output_graph_def) {
|
GraphDef* output_graph_def) {
|
||||||
// Some older GraphDefs have saved _output_shapes attributes which are out of
|
// Some older GraphDefs have saved _output_shapes attributes which are out of
|
||||||
// date and cause import errors, so clean them up first.
|
// date and cause import errors, so clean them up first.
|
||||||
@ -148,20 +148,24 @@ Status FoldConstants(const GraphDef& input_graph_def,
|
|||||||
ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
|
ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
|
||||||
DeviceAttributes device_attributes;
|
DeviceAttributes device_attributes;
|
||||||
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
|
||||||
&input_graph, inputs, outputs, {}, device_attributes));
|
&input_graph, context.input_names, context.output_names, {},
|
||||||
if (!DoConstantFolding(ConstantFoldingOptions(), nullptr, Env::Default(),
|
device_attributes));
|
||||||
nullptr, &input_graph)) {
|
bool was_mutated;
|
||||||
return errors::InvalidArgument("Constant folding failed");
|
TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
|
||||||
}
|
ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,
|
||||||
|
&was_mutated));
|
||||||
GraphDef folded_graph_def;
|
GraphDef folded_graph_def;
|
||||||
input_graph.ToGraphDef(&folded_graph_def);
|
input_graph.ToGraphDef(&folded_graph_def);
|
||||||
GraphDef send_recvs_replaced;
|
GraphDef send_recvs_replaced;
|
||||||
TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def, inputs,
|
TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def,
|
||||||
outputs, &send_recvs_replaced));
|
context.input_names, context.output_names,
|
||||||
TF_RETURN_IF_ERROR(RemoveUnusedNodes(send_recvs_replaced, inputs, outputs,
|
&send_recvs_replaced));
|
||||||
output_graph_def));
|
TF_RETURN_IF_ERROR(
|
||||||
|
RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace graph_transforms {
|
namespace graph_transforms {
|
||||||
@ -27,15 +28,13 @@ namespace graph_transforms {
|
|||||||
// the names of all the nodes that data is fed into, or read out of, when the
|
// the names of all the nodes that data is fed into, or read out of, when the
|
||||||
// graph is actually run.
|
// graph is actually run.
|
||||||
Status FoldConstants(const GraphDef& input_graph_def,
|
Status FoldConstants(const GraphDef& input_graph_def,
|
||||||
const std::vector<string>& inputs,
|
const TransformFuncContext& context,
|
||||||
const std::vector<string>& outputs,
|
|
||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
// Analyzes which nodes are used for the given set of inputs and outputs, and
|
// Analyzes which nodes are used for the given set of inputs and outputs, and
|
||||||
// returns a copy of the graph with any that aren't used removed.
|
// returns a copy of the graph with any that aren't used removed.
|
||||||
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
|
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
|
||||||
const std::vector<string>& inputs,
|
const TransformFuncContext& context,
|
||||||
const std::vector<string>& outputs,
|
|
||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
|
@ -82,12 +82,13 @@ class ConstantFoldingTest : public ::testing::Test {
|
|||||||
TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors));
|
TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors));
|
||||||
|
|
||||||
GraphDef folded_graph_def;
|
GraphDef folded_graph_def;
|
||||||
std::vector<string> input_names;
|
graph_transforms::TransformFuncContext context;
|
||||||
for (const std::pair<string, Tensor>& input : inputs) {
|
for (const std::pair<string, Tensor>& input : inputs) {
|
||||||
input_names.push_back(input.first);
|
context.input_names.push_back(input.first);
|
||||||
}
|
}
|
||||||
TF_ASSERT_OK(graph_transforms::FoldConstants(graph_def, input_names,
|
context.output_names = outputs;
|
||||||
outputs, &folded_graph_def));
|
TF_ASSERT_OK(
|
||||||
|
graph_transforms::FoldConstants(graph_def, context, &folded_graph_def));
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::Session> folded_session(
|
std::unique_ptr<tensorflow::Session> folded_session(
|
||||||
tensorflow::NewSession(tensorflow::SessionOptions()));
|
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||||
@ -187,7 +188,7 @@ class ConstantFoldingTest : public ::testing::Test {
|
|||||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
GraphDef result_graph_def;
|
GraphDef result_graph_def;
|
||||||
TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
|
TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
|
||||||
graph_def, {"placeholder"}, {"output"}, &result_graph_def));
|
graph_def, {{"placeholder"}, {"output"}}, &result_graph_def));
|
||||||
|
|
||||||
std::map<string, const NodeDef*> node_map;
|
std::map<string, const NodeDef*> node_map;
|
||||||
graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
|
graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
|
||||||
|
@ -1,110 +0,0 @@
|
|||||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// Utility that transforms a model with subgraphs that evaluate to constant
|
|
||||||
// functions into the equivalent model with those subgraphs replaced by Const
|
|
||||||
// nodes. This simplifies the graph, and makes some further transformations
|
|
||||||
// easier to perform. It's often useful to run the freeze_graph tool on the
|
|
||||||
// input graph beforehand to ensure variables have been transformed to Consts.
|
|
||||||
//
|
|
||||||
// bazel-bin/tensorflow/tools/graph_transforms/fold_constants_tool \
|
|
||||||
// --in_graph=graph_def.pb \
|
|
||||||
// --out_graph=folded_graph_def.pb \
|
|
||||||
// --inputs=input1,input2 \
|
|
||||||
// --outputs=output1,output2
|
|
||||||
//
|
|
||||||
// Parameters:
|
|
||||||
// in_graph - name of a file with a frozen GraphDef proto in binary format.
|
|
||||||
// out_graph - name of the output file to save the folded version to.
|
|
||||||
// inputs - layer names of the nodes that will be fed data.
|
|
||||||
// outputs - layer names of the nodes that will be read from after running.
|
|
||||||
|
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
|
||||||
#include "tensorflow/core/platform/env.h"
|
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
|
||||||
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
|
|
||||||
string in_graph = "";
|
|
||||||
string out_graph = "";
|
|
||||||
string inputs_string = "";
|
|
||||||
string outputs_string = "";
|
|
||||||
std::vector<Flag> flag_list = {
|
|
||||||
Flag("in_graph", &in_graph, "input graph file name"),
|
|
||||||
Flag("out_graph", &out_graph, "output graph file name"),
|
|
||||||
Flag("inputs", &inputs_string, "inputs"),
|
|
||||||
Flag("outputs", &outputs_string, "outputs"),
|
|
||||||
};
|
|
||||||
string usage = Flags::Usage(argv[0], flag_list);
|
|
||||||
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
|
|
||||||
// We need to call this to set up global state for TensorFlow.
|
|
||||||
port::InitMain(argv[0], &argc, &argv);
|
|
||||||
if (!parse_result) {
|
|
||||||
LOG(ERROR) << usage;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
if (argc > 1) {
|
|
||||||
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
if (in_graph.empty()) {
|
|
||||||
LOG(ERROR) << "in_graph graph can't be empty";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
if (out_graph.empty()) {
|
|
||||||
LOG(ERROR) << "out_graph graph can't be empty";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
std::vector<string> inputs = str_util::Split(inputs_string, ',');
|
|
||||||
std::vector<string> outputs = str_util::Split(outputs_string, ',');
|
|
||||||
|
|
||||||
GraphDef graph_def;
|
|
||||||
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
|
|
||||||
if (!load_status.ok()) {
|
|
||||||
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
|
|
||||||
<< load_status.error_message();
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphDef folded_graph_def;
|
|
||||||
Status folding_result = graph_transforms::FoldConstants(
|
|
||||||
graph_def, inputs, outputs, &folded_graph_def);
|
|
||||||
if (!folding_result.ok()) {
|
|
||||||
LOG(ERROR) << "Folding failed " << folding_result.error_message();
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status save_status =
|
|
||||||
WriteBinaryProto(Env::Default(), out_graph, folded_graph_def);
|
|
||||||
if (!save_status.ok()) {
|
|
||||||
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
|
|
||||||
<< save_status.error_message();
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
int main(int argc, char* argv[]) {
|
|
||||||
return tensorflow::ParseFlagsAndConvertGraph(argc, argv);
|
|
||||||
}
|
|
193
tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
Normal file
193
tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
namespace {
|
||||||
|
// Ensures the tensor is the expected shape.
|
||||||
|
Status ErrorIfNotVector(const Tensor& input, const string& input_name,
|
||||||
|
int expected_width) {
|
||||||
|
if ((input.shape().dims() != 1) ||
|
||||||
|
(input.shape().dim_size(0) != expected_width)) {
|
||||||
|
return errors::InvalidArgument(input_name,
|
||||||
|
" input to batch norm has bad shape: ",
|
||||||
|
input.shape().DebugString());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Finds monolithic batch norm ops (as used in early versions of TensorFlow) and
|
||||||
|
// converts them into premultiplied weight inputs to convolutions.
|
||||||
|
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
GraphDef current_graph_def = input_graph_def;
|
||||||
|
// We have to do several passes to catch all the old BN nodes, since many of
|
||||||
|
// them may share inputs and so be excluded from replacement in one pass.
|
||||||
|
bool did_graph_change;
|
||||||
|
do {
|
||||||
|
did_graph_change = false;
|
||||||
|
GraphDef replaced_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
current_graph_def, // clang-format off
|
||||||
|
{"BatchNormWithGlobalNormalization", // batch_norm_node
|
||||||
|
{
|
||||||
|
{"Conv2D", // conv_node
|
||||||
|
{
|
||||||
|
{"*"}, // input_node
|
||||||
|
{"Const"}, // weights_node
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Const"}, // mean_node
|
||||||
|
{"Const"}, // variance_node
|
||||||
|
{"Const"}, // beta_node
|
||||||
|
{"Const"}, // gamma_node
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[&did_graph_change](const NodeMatch& match,
|
||||||
|
const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
// Find all the nodes we expect in the subgraph.
|
||||||
|
const NodeDef& batch_norm_node = match.node;
|
||||||
|
CHECK_EQ("BatchNormWithGlobalNormalization", batch_norm_node.op());
|
||||||
|
const NodeDef& conv_node = match.inputs[0].node;
|
||||||
|
CHECK_EQ("Conv2D", conv_node.op());
|
||||||
|
const NodeDef& input_node = match.inputs[0].inputs[0].node;
|
||||||
|
const NodeDef& weights_node = match.inputs[0].inputs[1].node;
|
||||||
|
CHECK_EQ("Const", weights_node.op());
|
||||||
|
const NodeDef& mean_node = match.inputs[1].node;
|
||||||
|
CHECK_EQ("Const", mean_node.op());
|
||||||
|
const NodeDef& variance_node = match.inputs[2].node;
|
||||||
|
CHECK_EQ("Const", variance_node.op());
|
||||||
|
const NodeDef& beta_node = match.inputs[3].node;
|
||||||
|
CHECK_EQ("Const", beta_node.op());
|
||||||
|
const NodeDef& gamma_node = match.inputs[4].node;
|
||||||
|
CHECK_EQ("Const", gamma_node.op());
|
||||||
|
|
||||||
|
// We have a set of vectors that we want to combine into a vector of
|
||||||
|
// scale values to apply column-wise to the weight input to the conv,
|
||||||
|
// and an offset vector that we'll apply to the output of the conv.
|
||||||
|
Tensor weights = GetNodeTensorAttr(weights_node, "value");
|
||||||
|
Tensor mean = GetNodeTensorAttr(mean_node, "value");
|
||||||
|
Tensor variance = GetNodeTensorAttr(variance_node, "value");
|
||||||
|
Tensor beta = GetNodeTensorAttr(beta_node, "value");
|
||||||
|
Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
|
||||||
|
const float variance_epsilon =
|
||||||
|
batch_norm_node.attr().at("variance_epsilon").f();
|
||||||
|
const bool scale_after_normalization =
|
||||||
|
batch_norm_node.attr().at("scale_after_normalization").b();
|
||||||
|
|
||||||
|
// Make sure all the inputs really are vectors, with as many entries
|
||||||
|
// as there are columns in the weights.
|
||||||
|
const int64 weights_cols = weights.shape().dim_size(3);
|
||||||
|
TF_RETURN_IF_ERROR(ErrorIfNotVector(mean, "Mean", weights_cols));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ErrorIfNotVector(variance, "Variance", weights_cols));
|
||||||
|
TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", weights_cols));
|
||||||
|
TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma", weights_cols));
|
||||||
|
|
||||||
|
// Calculate the scale and offset values to apply.
|
||||||
|
std::vector<float> scale_values(weights_cols);
|
||||||
|
std::vector<float> offset_values(weights_cols);
|
||||||
|
if (scale_after_normalization) {
|
||||||
|
for (int i = 0; i < weights_cols; ++i) {
|
||||||
|
scale_values[i] =
|
||||||
|
(1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
|
||||||
|
gamma.flat<float>()(i);
|
||||||
|
offset_values[i] = 0.0f;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < weights_cols; ++i) {
|
||||||
|
scale_values[i] =
|
||||||
|
(1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
|
||||||
|
offset_values[i] = (-mean.flat<float>()(i) * scale_values[i]) +
|
||||||
|
beta.flat<float>()(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply the original weights by the scale vector.
|
||||||
|
auto weights_matrix = weights.flat_inner_dims<float>();
|
||||||
|
Tensor scaled_weights(DT_FLOAT, weights.shape());
|
||||||
|
auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
|
||||||
|
for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
|
||||||
|
for (int64 col = 0; col < weights_cols; ++col) {
|
||||||
|
scaled_weights_matrix(row, col) =
|
||||||
|
weights_matrix(row, col) * scale_values[col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Figure out the remaining bias to add on.
|
||||||
|
Tensor bias_offset(DT_FLOAT, {weights_cols});
|
||||||
|
auto bias_offset_vector = bias_offset.flat<float>();
|
||||||
|
for (int64 col = 0; col < weights_cols; ++col) {
|
||||||
|
bias_offset_vector(col) = offset_values[col];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the new nodes.
|
||||||
|
NodeDef scaled_weights_node;
|
||||||
|
scaled_weights_node.set_op("Const");
|
||||||
|
scaled_weights_node.set_name(weights_node.name());
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
|
||||||
|
SetNodeTensorAttr<float>("value", scaled_weights,
|
||||||
|
&scaled_weights_node);
|
||||||
|
new_nodes->push_back(scaled_weights_node);
|
||||||
|
|
||||||
|
// The input and convolution can be copied straight over, since the
|
||||||
|
// name of the scaled weights constant is the same as the original.
|
||||||
|
new_nodes->push_back(input_node);
|
||||||
|
new_nodes->push_back(conv_node);
|
||||||
|
|
||||||
|
NodeDef bias_offset_node;
|
||||||
|
bias_offset_node.set_op("Const");
|
||||||
|
bias_offset_node.set_name(conv_node.name() + "_bn_offset");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
|
||||||
|
SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
|
||||||
|
new_nodes->push_back(bias_offset_node);
|
||||||
|
|
||||||
|
NodeDef bias_add_node;
|
||||||
|
bias_add_node.set_op("BiasAdd");
|
||||||
|
bias_add_node.set_name(batch_norm_node.name());
|
||||||
|
CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
|
||||||
|
AddNodeInput(conv_node.name(), &bias_add_node);
|
||||||
|
AddNodeInput(bias_offset_node.name(), &bias_add_node);
|
||||||
|
new_nodes->push_back(bias_add_node);
|
||||||
|
|
||||||
|
did_graph_change = true;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &replaced_graph_def));
|
||||||
|
current_graph_def = replaced_graph_def;
|
||||||
|
} while (did_graph_change);
|
||||||
|
*output_graph_def = current_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
128
tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
Normal file
128
tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class FoldOldBatchNormsTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestFoldOldBatchNorms() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
|
||||||
|
-5.0f, -3.0f, -6.0f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
|
||||||
|
test::FillValues<float>(&weights_data,
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
|
||||||
|
{1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
Tensor mean_data(DT_FLOAT, TensorShape({2}));
|
||||||
|
test::FillValues<float>(&mean_data, {10.0f, 20.0f});
|
||||||
|
Output mean_op =
|
||||||
|
Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
|
||||||
|
|
||||||
|
Tensor variance_data(DT_FLOAT, TensorShape({2}));
|
||||||
|
test::FillValues<float>(&variance_data, {0.25f, 0.5f});
|
||||||
|
Output variance_op = Const(root.WithOpName("variance_op"),
|
||||||
|
Input::Initializer(variance_data));
|
||||||
|
|
||||||
|
Tensor beta_data(DT_FLOAT, TensorShape({2}));
|
||||||
|
test::FillValues<float>(&beta_data, {0.1f, 0.6f});
|
||||||
|
Output beta_op =
|
||||||
|
Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
|
||||||
|
|
||||||
|
Tensor gamma_data(DT_FLOAT, TensorShape({2}));
|
||||||
|
test::FillValues<float>(&gamma_data, {1.0f, 2.0f});
|
||||||
|
Output gamma_op =
|
||||||
|
Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
// This is needed because we're trying to convert over a deprecated op which
|
||||||
|
// should only be present in older GraphDef files. Without this we see a
|
||||||
|
// deprecation error.
|
||||||
|
// This is justified because we're trying to test a tool that is expected to
|
||||||
|
// run on legacy files, to help users convert over to less problematic
|
||||||
|
// versions.
|
||||||
|
NodeDef batch_norm_node;
|
||||||
|
batch_norm_node.set_op("BatchNormWithGlobalNormalization");
|
||||||
|
batch_norm_node.set_name("output");
|
||||||
|
AddNodeInput("conv_op", &batch_norm_node);
|
||||||
|
AddNodeInput("mean_op", &batch_norm_node);
|
||||||
|
AddNodeInput("variance_op", &batch_norm_node);
|
||||||
|
AddNodeInput("beta_op", &batch_norm_node);
|
||||||
|
AddNodeInput("gamma_op", &batch_norm_node);
|
||||||
|
SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
|
||||||
|
SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
|
||||||
|
SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
|
||||||
|
*(original_graph_def.mutable_node()->Add()) = batch_norm_node;
|
||||||
|
original_graph_def.mutable_versions()->set_producer(8);
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef fused_graph_def;
|
||||||
|
TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
|
||||||
|
&fused_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
|
||||||
|
std::vector<Tensor> fused_outputs;
|
||||||
|
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
|
||||||
|
|
||||||
|
for (const NodeDef& node : fused_graph_def.node()) {
|
||||||
|
EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNorms) {
|
||||||
|
TestFoldOldBatchNorms();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
200
tensorflow/tools/graph_transforms/fuse_convolutions.cc
Normal file
200
tensorflow/tools/graph_transforms/fuse_convolutions.cc
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Replaces every ResizeBilinear -> MirrorPad -> Conv2D chain with a single
// FusedResizeAndPadConv2D node that keeps the name of the matched Conv2D.
// The result is written to output_graph_def only if the replacement succeeds;
// otherwise the error status is returned.
Status FuseResizePadAndConv(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  // Produces the replacement nodes for one matched subgraph.
  const auto emit_fused_nodes =
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // Name the pieces of the matched pattern.
        const NodeDef& conv = match.node;
        const NodeMatch& pad_match = match.inputs[0];
        const NodeDef& mirror_pad = pad_match.node;
        const NodeDef& weights = match.inputs[1].node;
        const NodeDef& resize = pad_match.inputs[0].node;
        const NodeDef& pad_dims = pad_match.inputs[1].node;

        // The weights and padding dimensions are carried over unchanged.
        new_nodes->push_back(weights);
        new_nodes->push_back(pad_dims);

        // Fused op inputs: both resize inputs, the paddings, then the filter.
        NodeDef fused;
        fused.set_op("FusedResizeAndPadConv2D");
        fused.set_name(conv.name());
        AddNodeInput(resize.input(0), &fused);
        AddNodeInput(resize.input(1), &fused);
        AddNodeInput(mirror_pad.input(1), &fused);
        AddNodeInput(conv.input(1), &fused);
        CopyNodeAttr(resize, "align_corners", "resize_align_corners", &fused);
        CopyNodeAttr(mirror_pad, "mode", "mode", &fused);
        CopyNodeAttr(conv, "T", "T", &fused);
        CopyNodeAttr(conv, "padding", "padding", &fused);
        CopyNodeAttr(conv, "strides", "strides", &fused);
        new_nodes->push_back(fused);

        return Status::OK();
      };

  GraphDef fused_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Conv2D",
        {
          {"MirrorPad",
            {
              {"ResizeBilinear"},
              {"*"}
            }
          },
          {"*"}
        }
      },  // clang-format on
      emit_fused_nodes, {}, &fused_graph_def));
  *output_graph_def = fused_graph_def;
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Replaces every ResizeBilinear -> Conv2D pair with a single
// FusedResizeAndPadConv2D node that keeps the name of the matched Conv2D.
// Since there is no padding op in the pattern, an all-zero paddings constant
// is synthesized for the fused op. The result is written to output_graph_def
// only if the replacement succeeds; otherwise the error status is returned.
Status FuseResizeAndConv(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  // Produces the replacement nodes for one matched subgraph.
  const auto emit_fused_nodes =
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // Name the pieces of the matched pattern.
        const NodeDef& conv = match.node;
        const NodeDef& resize = match.inputs[0].node;
        const NodeDef& weights = match.inputs[1].node;

        // The weights are carried over unchanged.
        new_nodes->push_back(weights);

        // Dummy 4x2 paddings constant filled with zeros, so the padding stage
        // of the fused op has no effect.
        NodeDef zero_paddings;
        zero_paddings.set_op("Const");
        zero_paddings.set_name(conv.name() + "_dummy_paddings");
        SetNodeAttr("dtype", DT_INT32, &zero_paddings);
        SetNodeTensorAttr<int32>("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0},
                                 &zero_paddings);
        new_nodes->push_back(zero_paddings);

        // Fused op inputs: both resize inputs, the paddings, then the filter.
        NodeDef fused;
        fused.set_op("FusedResizeAndPadConv2D");
        fused.set_name(conv.name());
        AddNodeInput(resize.input(0), &fused);
        AddNodeInput(resize.input(1), &fused);
        AddNodeInput(zero_paddings.name(), &fused);
        AddNodeInput(conv.input(1), &fused);
        CopyNodeAttr(resize, "align_corners", "resize_align_corners", &fused);
        SetNodeAttr("mode", "REFLECT", &fused);
        CopyNodeAttr(conv, "T", "T", &fused);
        CopyNodeAttr(conv, "padding", "padding", &fused);
        CopyNodeAttr(conv, "strides", "strides", &fused);
        new_nodes->push_back(fused);

        return Status::OK();
      };

  GraphDef fused_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Conv2D",
        {
          {"ResizeBilinear"},
          {"*"}
        }
      },  // clang-format on
      emit_fused_nodes, {}, &fused_graph_def));
  *output_graph_def = fused_graph_def;
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Replaces every MirrorPad -> Conv2D pair with a single FusedPadConv2D node
// that keeps the name of the matched Conv2D. The result is written to
// output_graph_def only if the replacement succeeds; otherwise the error
// status is returned.
Status FusePadAndConv(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  // Produces the replacement nodes for one matched subgraph.
  const auto emit_fused_nodes =
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // Name the pieces of the matched pattern, sanity-checking the two
        // ops the pattern pins down.
        const NodeDef& conv = match.node;
        CHECK_EQ("Conv2D", conv.op());
        const NodeMatch& pad_match = match.inputs[0];
        const NodeDef& mirror_pad = pad_match.node;
        CHECK_EQ("MirrorPad", mirror_pad.op());
        const NodeDef& weights = match.inputs[1].node;
        const NodeDef& padded_input = pad_match.inputs[0].node;
        const NodeDef& pad_dims = pad_match.inputs[1].node;

        // Both MirrorPad inputs are part of the match, so they are re-emitted
        // unchanged together with the weights.
        new_nodes->push_back(weights);
        new_nodes->push_back(padded_input);
        new_nodes->push_back(pad_dims);

        // Fused op inputs: the padded tensor, the paddings, then the filter.
        NodeDef fused;
        fused.set_op("FusedPadConv2D");
        fused.set_name(conv.name());
        AddNodeInput(mirror_pad.input(0), &fused);
        AddNodeInput(mirror_pad.input(1), &fused);
        AddNodeInput(conv.input(1), &fused);
        CopyNodeAttr(mirror_pad, "mode", "mode", &fused);
        CopyNodeAttr(conv, "T", "T", &fused);
        CopyNodeAttr(conv, "padding", "padding", &fused);
        CopyNodeAttr(conv, "strides", "strides", &fused);
        new_nodes->push_back(fused);

        return Status::OK();
      };

  GraphDef fused_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Conv2D",
        {
          {"MirrorPad",
            {
              {"*"},
              {"*"},
            }
          },
          {"*"}
        }
      },  // clang-format on
      emit_fused_nodes, {}, &fused_graph_def));
  *output_graph_def = fused_graph_def;
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Registers FuseResizePadAndConv under the transform name
// "fuse_resize_pad_and_conv".
REGISTER_GRAPH_TRANSFORM("fuse_resize_pad_and_conv", FuseResizePadAndConv);
|
||||||
|
|
||||||
|
// Registers FuseResizeAndConv under the transform name "fuse_resize_and_conv".
REGISTER_GRAPH_TRANSFORM("fuse_resize_and_conv", FuseResizeAndConv);
|
||||||
|
|
||||||
|
// Registers FusePadAndConv under the transform name "fuse_pad_and_conv".
REGISTER_GRAPH_TRANSFORM("fuse_pad_and_conv", FusePadAndConv);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
212
tensorflow/tools/graph_transforms/fuse_convolutions_test.cc
Normal file
212
tensorflow/tools/graph_transforms/fuse_convolutions_test.cc
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status FuseResizePadAndConv(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
Status FuseResizeAndConv(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
Status FusePadAndConv(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class FuseConvolutionsTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestFuseResizePadAndConv() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
|
||||||
|
-5.0f, -3.0f, -6.0f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op,
|
||||||
|
Const(root.WithOpName("size"), {12, 4}),
|
||||||
|
ResizeBilinear::AlignCorners(false));
|
||||||
|
|
||||||
|
Tensor pad_dims_data(DT_INT32, TensorShape({4, 2}));
|
||||||
|
test::FillValues<int32>(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0});
|
||||||
|
Output pad_dims_op = Const(root.WithOpName("pad_dims_op"),
|
||||||
|
Input::Initializer(pad_dims_data));
|
||||||
|
Output pad_op =
|
||||||
|
MirrorPad(root.WithOpName("pad_op"), resize_op, pad_dims_op, "REFLECT");
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
|
||||||
|
test::FillValues<float>(&weights_data,
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op,
|
||||||
|
{1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef fused_graph_def;
|
||||||
|
TF_ASSERT_OK(FuseResizePadAndConv(original_graph_def, {{}, {"output"}},
|
||||||
|
&fused_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
|
||||||
|
std::vector<Tensor> fused_outputs;
|
||||||
|
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
|
||||||
|
|
||||||
|
for (const NodeDef& node : fused_graph_def.node()) {
|
||||||
|
EXPECT_NE("Conv2D", node.op());
|
||||||
|
EXPECT_NE("MirrorPad", node.op());
|
||||||
|
EXPECT_NE("ResizeBilinear", node.op());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestFuseResizeAndConv() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
|
||||||
|
-5.0f, -3.0f, -6.0f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Output resize_op = ResizeBilinear(root.WithOpName("resize_op"), input_op,
|
||||||
|
Const(root.WithOpName("size"), {12, 4}),
|
||||||
|
ResizeBilinear::AlignCorners(false));
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
|
||||||
|
test::FillValues<float>(&weights_data,
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = Conv2D(root.WithOpName("output"), resize_op, weights_op,
|
||||||
|
{1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef fused_graph_def;
|
||||||
|
TF_ASSERT_OK(FuseResizeAndConv(original_graph_def, {{}, {"output"}},
|
||||||
|
&fused_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
|
||||||
|
std::vector<Tensor> fused_outputs;
|
||||||
|
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
|
||||||
|
|
||||||
|
for (const NodeDef& node : fused_graph_def.node()) {
|
||||||
|
EXPECT_NE("Conv2D", node.op());
|
||||||
|
EXPECT_NE("ResizeBilinear", node.op());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestFusePadAndConv() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 2, 3, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
|
||||||
|
-5.0f, -3.0f, -6.0f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Tensor pad_dims_data(DT_INT32, TensorShape({4, 2}));
|
||||||
|
test::FillValues<int32>(&pad_dims_data, {0, 0, 1, 1, 2, 2, 0, 0});
|
||||||
|
Output pad_dims_op = Const(root.WithOpName("pad_dims_op"),
|
||||||
|
Input::Initializer(pad_dims_data));
|
||||||
|
Output pad_op =
|
||||||
|
MirrorPad(root.WithOpName("pad_op"), input_op, pad_dims_op, "REFLECT");
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
|
||||||
|
test::FillValues<float>(&weights_data,
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = Conv2D(root.WithOpName("output"), pad_op, weights_op,
|
||||||
|
{1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef fused_graph_def;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
FusePadAndConv(original_graph_def, {{}, {"output"}}, &fused_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(fused_session->Create(fused_graph_def));
|
||||||
|
std::vector<Tensor> fused_outputs;
|
||||||
|
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
|
||||||
|
|
||||||
|
for (const NodeDef& node : fused_graph_def.node()) {
|
||||||
|
EXPECT_NE("Conv2D", node.op());
|
||||||
|
EXPECT_NE("MirrorPad", node.op());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FuseConvolutionsTest, TestFuseResizePadAndConv) {
|
||||||
|
TestFuseResizePadAndConv();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FuseConvolutionsTest, TestFuseResizeAndConv) { TestFuseResizeAndConv(); }
|
||||||
|
|
||||||
|
TEST_F(FuseConvolutionsTest, TestFusePadAndConv) { TestFusePadAndConv(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
104
tensorflow/tools/graph_transforms/obsfucate_names.cc
Normal file
104
tensorflow/tools/graph_transforms/obsfucate_names.cc
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Renames all nodes not uses as graph inputs or outputs to short numerical
|
||||||
|
// forms.
|
||||||
|
Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
std::unordered_set<string> required_nodes;
|
||||||
|
for (const string& input : context.input_names) {
|
||||||
|
required_nodes.insert(input);
|
||||||
|
}
|
||||||
|
for (const string& output : context.output_names) {
|
||||||
|
required_nodes.insert(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const string& required_node : required_nodes) {
|
||||||
|
LOG(INFO) << "required_node=" << required_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
const string valid_chars =
|
||||||
|
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||||
|
const int64 chars_size = valid_chars.size();
|
||||||
|
|
||||||
|
std::map<string, string> new_names;
|
||||||
|
int64 name_index = 0;
|
||||||
|
for (const NodeDef& input_node : input_graph_def.node()) {
|
||||||
|
const string& old_name = input_node.name();
|
||||||
|
string new_name;
|
||||||
|
if (required_nodes.count(old_name)) {
|
||||||
|
new_name = old_name;
|
||||||
|
} else {
|
||||||
|
do {
|
||||||
|
int64 remaining = name_index;
|
||||||
|
new_name = "";
|
||||||
|
while (true) {
|
||||||
|
const int64 remainder = (remaining % chars_size);
|
||||||
|
const char current_char = valid_chars[remainder];
|
||||||
|
new_name = current_char + new_name;
|
||||||
|
remaining /= chars_size;
|
||||||
|
if (remaining <= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++name_index;
|
||||||
|
} while (required_nodes.count(new_name));
|
||||||
|
}
|
||||||
|
new_names[old_name] = new_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& input_node : input_graph_def.node()) {
|
||||||
|
NodeDef* node = output_graph_def->mutable_node()->Add();
|
||||||
|
node->CopyFrom(input_node);
|
||||||
|
const string& old_name = input_node.name();
|
||||||
|
node->set_name(new_names[old_name]);
|
||||||
|
node->mutable_input()->Clear();
|
||||||
|
for (const string& input_name : input_node.input()) {
|
||||||
|
string prefix;
|
||||||
|
string input_node_name;
|
||||||
|
string suffix;
|
||||||
|
NodeNamePartsFromInput(input_name, &prefix, &input_node_name, &suffix);
|
||||||
|
if (new_names.count(input_node_name) == 0) {
|
||||||
|
return errors::InvalidArgument("No node named ", input_node_name,
|
||||||
|
" for input to ", old_name);
|
||||||
|
}
|
||||||
|
string new_input_name = prefix + new_names[input_node_name] + suffix;
|
||||||
|
*(node->mutable_input()->Add()) = new_input_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers ObsfucateNames under the transform name "obsfucate_names".
// NOTE(review): "obsfucate" looks like a misspelling of "obfuscate", but the
// string is what this registration uses, so it is left unchanged here.
REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
142
tensorflow/tools/graph_transforms/obsfucate_names_test.cc
Normal file
142
tensorflow/tools/graph_transforms/obsfucate_names_test.cc
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class ObsfucateNamesTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestSimpleTree() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* add_node1 = graph_def.add_node();
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->add_input("add_node2");
|
||||||
|
add_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
add_node3->add_input("const_node4");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node4 = graph_def.add_node();
|
||||||
|
const_node4->set_name("const_node4");
|
||||||
|
const_node4->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node1"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("add_node2"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("add_node3"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node1"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("const_node2"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("const_node3"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("const_node4"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestManyNodes() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
for (int i = 0; i < 1000; ++i) {
|
||||||
|
NodeDef* const_node = graph_def.add_node();
|
||||||
|
const_node->set_name(strings::StrCat("const_node", i));
|
||||||
|
const_node->set_op("Const");
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}},
|
||||||
|
&result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node0"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("const_node500"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node999"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestNameClashes() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
for (int i = 0; i < 1000; ++i) {
|
||||||
|
NodeDef* const_node = graph_def.add_node();
|
||||||
|
const_node->set_name(strings::StrCat("1", i));
|
||||||
|
const_node->set_op("Const");
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("10"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("19"));
|
||||||
|
|
||||||
|
std::unordered_set<string> names;
|
||||||
|
for (const NodeDef& node : result.node()) {
|
||||||
|
EXPECT_EQ(0, names.count(node.name()))
|
||||||
|
<< "Found multiple nodes with name '" << node.name() << "'";
|
||||||
|
names.insert(node.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); }
|
||||||
|
|
||||||
|
TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); }
|
||||||
|
|
||||||
|
TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
922
tensorflow/tools/graph_transforms/quantize_nodes.cc
Normal file
922
tensorflow/tools/graph_transforms/quantize_nodes.cc
Normal file
@ -0,0 +1,922 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Holds the information we need to translate from a float version of this op
|
||||||
|
// into the quantized equivalent.
|
||||||
|
struct QuantizedOpInfo {
|
||||||
|
// The name of the float op.
|
||||||
|
string float_name;
|
||||||
|
// Which attributes to copy directly over.
|
||||||
|
std::vector<string> attrs_to_copy;
|
||||||
|
// Extra data type attributes we need to set.
|
||||||
|
std::vector<std::pair<string, DataType>> dtypes_to_set;
|
||||||
|
// What depth of inputs the op can read in.
|
||||||
|
DataType input_bit_depth;
|
||||||
|
// The depth of the op's quantized outputs.
|
||||||
|
DataType output_bit_depth;
|
||||||
|
// Which inputs (e.g. shapes) aren't involved in the quantization process.
|
||||||
|
std::set<int32> unquantized_inputs;
|
||||||
|
// How the outputs are arranged, either
|
||||||
|
// [input0, input1, min0, max0, min1, max1] for contiguous, or
|
||||||
|
// [input0, input1, min0, min1, max0, max1] for separate.
|
||||||
|
// The separate order is needed because it's the only way to specify unknown
|
||||||
|
// numbers of inputs for ops like Concat.
|
||||||
|
enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Every op that has a quantized equivalent should be listed here, so that the
|
||||||
|
// conversion process can transform them.
|
||||||
|
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
|
||||||
|
static const std::vector<QuantizedOpInfo> op_list = {
|
||||||
|
{"AvgPool",
|
||||||
|
{"ksize", "strides", "padding"},
|
||||||
|
{{"T", DT_QUINT8}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QUINT8,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"BiasAdd",
|
||||||
|
{},
|
||||||
|
{{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QINT32,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"Concat",
|
||||||
|
{"N"},
|
||||||
|
{{"T", DT_QUINT8}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QUINT8,
|
||||||
|
{0},
|
||||||
|
QuantizedOpInfo::SEPARATE_MIN_MAX},
|
||||||
|
{"Conv2D",
|
||||||
|
{"strides", "padding"},
|
||||||
|
{{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QINT32,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"MatMul",
|
||||||
|
{"transpose_a", "transpose_b"},
|
||||||
|
{{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QINT32,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"MaxPool",
|
||||||
|
{"ksize", "strides", "padding"},
|
||||||
|
{{"T", DT_QUINT8}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QUINT8,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"Relu",
|
||||||
|
{},
|
||||||
|
{{"Tinput", DT_QUINT8}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QUINT8,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"Relu6",
|
||||||
|
{},
|
||||||
|
{{"Tinput", DT_QUINT8}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QUINT8,
|
||||||
|
{},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
{"Reshape",
|
||||||
|
{},
|
||||||
|
{{"T", DT_QUINT8}},
|
||||||
|
DT_QUINT8,
|
||||||
|
DT_QUINT8,
|
||||||
|
{1},
|
||||||
|
QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
|
||||||
|
};
|
||||||
|
return op_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Replaces invalid characters in input names to get a unique node name.
|
||||||
|
string UniqueNodeNameFromInput(const string& input_name) {
|
||||||
|
string prefix;
|
||||||
|
string node_name;
|
||||||
|
string suffix;
|
||||||
|
NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
|
||||||
|
string result;
|
||||||
|
if (prefix == "^") {
|
||||||
|
result += "__hat__";
|
||||||
|
}
|
||||||
|
result += node_name;
|
||||||
|
if (suffix != "") {
|
||||||
|
result += "__port__" + suffix.substr(1, suffix.size() - 1);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pulls two float values from the named parameters, with a lot of checking.
|
||||||
|
Status ExtractRangeFromParams(const TransformFuncContext& context,
|
||||||
|
const string& min_name, const string& max_name,
|
||||||
|
float* min_value, float* max_value,
|
||||||
|
bool* has_range) {
|
||||||
|
// See if we've been given quantized inputs with a known range.
|
||||||
|
const bool has_min = (context.params.count(min_name) != 0);
|
||||||
|
const bool has_max = (context.params.count(max_name) != 0);
|
||||||
|
*has_range = (has_min || has_max);
|
||||||
|
if (!*has_range) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
if (!has_min || !has_max) {
|
||||||
|
return errors::InvalidArgument("You must pass both ", min_name, " and ",
|
||||||
|
max_name, " into quantize_nodes");
|
||||||
|
}
|
||||||
|
std::vector<string> min_strings = context.params.at(min_name);
|
||||||
|
std::vector<string> max_strings = context.params.at(max_name);
|
||||||
|
if ((min_strings.size() != 1) || (max_strings.size() != 1)) {
|
||||||
|
return errors::InvalidArgument("You must pass a single ", min_name,
|
||||||
|
" and single ", max_name,
|
||||||
|
" value into "
|
||||||
|
"quantize_nodes");
|
||||||
|
}
|
||||||
|
if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) {
|
||||||
|
return errors::InvalidArgument("Couldn't decode ", min_name,
|
||||||
|
" as a number: ", min_strings[0]);
|
||||||
|
}
|
||||||
|
if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) {
|
||||||
|
return errors::InvalidArgument("Couldn't decode ", max_name,
|
||||||
|
" as a number: ", max_strings[0]);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AreAttrsEqual(const NodeDef* current_node, const NodeDef* other_node) {
|
||||||
|
if (current_node->attr_size() != other_node->attr_size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
string current_serialized;
|
||||||
|
string other_serialized;
|
||||||
|
for (const auto& attr : other_node->attr()) {
|
||||||
|
auto iter = current_node->attr().find(attr.first);
|
||||||
|
if (iter == current_node->attr().end()) return false;
|
||||||
|
iter->second.SerializeToString(¤t_serialized);
|
||||||
|
attr.second.SerializeToString(&other_serialized);
|
||||||
|
if (current_serialized != other_serialized) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Analyzes all the nodes in the graph to figure out which ones are duplicates
|
||||||
|
// apart from their names. This commonly includes identical Const nodes, but can
|
||||||
|
// also be simple operations that are repeated on multiple outputs of a
|
||||||
|
// particular node. The complexity is managed using a hash function that avoids
|
||||||
|
// the need for any O(n^2) algorithms when identifying duplicates.
|
||||||
|
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
// Make sure we can look up inputs and outputs quickly.
|
||||||
|
std::set<string> input_names(context.input_names.begin(),
|
||||||
|
context.input_names.end());
|
||||||
|
std::set<string> output_names(context.output_names.begin(),
|
||||||
|
context.output_names.end());
|
||||||
|
GraphDef current_graph_def = input_graph_def;
|
||||||
|
// Keep running the merging until no more duplicates are found.
|
||||||
|
bool any_duplicates_found;
|
||||||
|
do {
|
||||||
|
any_duplicates_found = false;
|
||||||
|
// First arrange all of the nodes by a hash of their contents.
|
||||||
|
std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
|
||||||
|
for (const NodeDef& node : current_graph_def.node()) {
|
||||||
|
NodeDef nameless_node = node;
|
||||||
|
// The name matters if it's being used as an input or output node,
|
||||||
|
// otherwise ignore it when looking for duplicates.
|
||||||
|
if (!input_names.count(node.name()) && !output_names.count(node.name())) {
|
||||||
|
nameless_node.set_name("");
|
||||||
|
}
|
||||||
|
const uint64 hash = HashNodeDef(nameless_node);
|
||||||
|
hashed_nodes[hash].push_back(&node);
|
||||||
|
}
|
||||||
|
// If we have multiple nodes with the same hash, then we know they're
|
||||||
|
// duplicates and can be removed, unless they're stateful.
|
||||||
|
std::map<string, string> inputs_to_rename;
|
||||||
|
GraphDef merged_graph_def;
|
||||||
|
for (const std::pair<uint64, std::vector<const NodeDef*>> hashed_node_info :
|
||||||
|
hashed_nodes) {
|
||||||
|
const std::vector<const NodeDef*>& hash_node_list =
|
||||||
|
hashed_node_info.second;
|
||||||
|
for (int i = 0; i < hash_node_list.size(); ++i) {
|
||||||
|
const NodeDef* current_node = hash_node_list[i];
|
||||||
|
const OpDef* op_def = nullptr;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
|
||||||
|
const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
|
||||||
|
if (is_duplicate) {
|
||||||
|
const string original_name = hash_node_list[0]->name();
|
||||||
|
inputs_to_rename[current_node->name() + ":*"] = original_name;
|
||||||
|
any_duplicates_found = true;
|
||||||
|
} else {
|
||||||
|
NodeDef* new_node = merged_graph_def.mutable_node()->Add();
|
||||||
|
*new_node = *current_node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update the graph so that any nodes that referred to removed inputs now
|
||||||
|
// pull from the remaining duplicate.
|
||||||
|
RenameNodeInputs(merged_graph_def, inputs_to_rename, ¤t_graph_def);
|
||||||
|
} while (any_duplicates_found);
|
||||||
|
|
||||||
|
*output_graph_def = current_graph_def;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Looks for the patterns that indicate there are two eight-bit ops feeding into
|
||||||
|
// each other, separated by a conversion up to float and back again. These occur
|
||||||
|
// during the initial conversion of ops to their quantized forms. Because we're
|
||||||
|
// only looking at an individual op in that phase and don't know if its inputs
|
||||||
|
// and outputs are eight-bit-capable, we start by converting the actual op into
|
||||||
|
// quantized form, but add float conversions before and after. This pass gets
|
||||||
|
// rid of those conversions if it turns out we do have adjacent ops capable of
|
||||||
|
// eight-bit processing.
|
||||||
|
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
std::set<string> graph_outputs;
|
||||||
|
for (const string& output_name : context.output_names) {
|
||||||
|
graph_outputs.insert(NodeNameFromInput(output_name));
|
||||||
|
}
|
||||||
|
std::map<string, string> inputs_to_rename;
|
||||||
|
GraphDef replaced_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, // clang-format off
|
||||||
|
{"QuantizeV2",
|
||||||
|
{
|
||||||
|
{"Dequantize"},
|
||||||
|
{"Min"},
|
||||||
|
{"Max"},
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[&inputs_to_rename, &graph_outputs](const NodeMatch& match,
|
||||||
|
const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& quantize_node = match.node;
|
||||||
|
const NodeDef& dequantize_node = match.inputs[0].node;
|
||||||
|
inputs_to_rename[quantize_node.name() + ":0"] =
|
||||||
|
dequantize_node.input(0);
|
||||||
|
inputs_to_rename[quantize_node.name() + ":1"] =
|
||||||
|
dequantize_node.input(1);
|
||||||
|
inputs_to_rename[quantize_node.name() + ":2"] =
|
||||||
|
dequantize_node.input(2);
|
||||||
|
|
||||||
|
// Are other sub-graphs using the float intermediate result? If so,
|
||||||
|
// preserve it, but the input renaming still rewires the eight-bit ops
|
||||||
|
// so they don't go through float.
|
||||||
|
if (output_nodes.count(dequantize_node.name()) ||
|
||||||
|
graph_outputs.count(dequantize_node.name())) {
|
||||||
|
CopyOriginalMatch(match, new_nodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{true}, &replaced_graph_def));
|
||||||
|
|
||||||
|
RenameNodeInputs(replaced_graph_def, inputs_to_rename, output_graph_def);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the user has passed in the input_min and input_max args, then we need to
|
||||||
|
// convert any input placeholders from float to eight bit, so quantized inputs
|
||||||
|
// can be fed directly into the graph.
|
||||||
|
Status QuantizePlaceholders(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
float input_min;
|
||||||
|
float input_max;
|
||||||
|
bool has_input_range;
|
||||||
|
TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
|
||||||
|
&input_min, &input_max,
|
||||||
|
&has_input_range));
|
||||||
|
if (!has_input_range) {
|
||||||
|
*output_graph_def = input_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
std::map<string, string> inputs_to_rename_first_pass;
|
||||||
|
std::map<string, string> inputs_to_rename_second_pass;
|
||||||
|
GraphDef placeholder_graph_def;
|
||||||
|
placeholder_graph_def.Clear();
|
||||||
|
for (const NodeDef& node : input_graph_def.node()) {
|
||||||
|
if (node.op() != "Placeholder") {
|
||||||
|
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(node);
|
||||||
|
} else {
|
||||||
|
string namespace_prefix = node.name() + "_eightbit";
|
||||||
|
|
||||||
|
NodeDef quantized_placeholder;
|
||||||
|
quantized_placeholder.CopyFrom(node);
|
||||||
|
SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
|
||||||
|
(placeholder_graph_def.mutable_node()->Add())
|
||||||
|
->CopyFrom(quantized_placeholder);
|
||||||
|
|
||||||
|
NodeDef min_node;
|
||||||
|
min_node.set_op("Const");
|
||||||
|
min_node.set_name(namespace_prefix + "/min");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &min_node);
|
||||||
|
Tensor min_tensor(DT_FLOAT, {});
|
||||||
|
min_tensor.flat<float>()(0) = input_min;
|
||||||
|
SetNodeTensorAttr<float>("value", min_tensor, &min_node);
|
||||||
|
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(min_node);
|
||||||
|
|
||||||
|
NodeDef max_node;
|
||||||
|
max_node.set_op("Const");
|
||||||
|
max_node.set_name(namespace_prefix + "/max");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &max_node);
|
||||||
|
Tensor max_tensor(DT_FLOAT, {});
|
||||||
|
max_tensor.flat<float>()(0) = input_max;
|
||||||
|
SetNodeTensorAttr<float>("value", max_tensor, &max_node);
|
||||||
|
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(max_node);
|
||||||
|
|
||||||
|
const string rename_suffix = "__RENAMED_PLACEHOLDER__";
|
||||||
|
NodeDef dequantize_node;
|
||||||
|
dequantize_node.set_op("Dequantize");
|
||||||
|
dequantize_node.set_name(namespace_prefix + "/dequantize");
|
||||||
|
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
|
||||||
|
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
|
||||||
|
AddNodeInput(node.name() + rename_suffix, &dequantize_node);
|
||||||
|
AddNodeInput(min_node.name(), &dequantize_node);
|
||||||
|
AddNodeInput(max_node.name(), &dequantize_node);
|
||||||
|
(placeholder_graph_def.mutable_node()->Add())->CopyFrom(dequantize_node);
|
||||||
|
|
||||||
|
// First make sure that any internal references to the old placeholder
|
||||||
|
// now point to the dequantize result.
|
||||||
|
inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
|
||||||
|
// Then fix up the dequantize op so that it really points to the
|
||||||
|
// placeholder.
|
||||||
|
inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef first_pass_graph_def;
|
||||||
|
RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
|
||||||
|
&first_pass_graph_def);
|
||||||
|
RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
|
||||||
|
output_graph_def);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// During training, FakeQuantWithMinMaxVars ops capture a good min/max range for
|
||||||
|
// an activation layer. To use these during inference, this pass converts those
|
||||||
|
// ops into Requantizes with the trained min/maxes as constant inputs.
|
||||||
|
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, // clang-format off
|
||||||
|
{"FakeQuantWithMinMaxVars",
|
||||||
|
{
|
||||||
|
{"*"},
|
||||||
|
{"Const"},
|
||||||
|
{"Const"},
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& fake_quant_node = match.node;
|
||||||
|
const NodeDef& original_op_node = match.inputs[0].node;
|
||||||
|
const NodeDef& fake_quant_min_node = match.inputs[1].node;
|
||||||
|
const NodeDef& fake_quant_max_node = match.inputs[2].node;
|
||||||
|
|
||||||
|
string namespace_prefix = fake_quant_node.name() + "_eightbit";
|
||||||
|
|
||||||
|
new_nodes->push_back(original_op_node);
|
||||||
|
new_nodes->push_back(fake_quant_min_node);
|
||||||
|
new_nodes->push_back(fake_quant_max_node);
|
||||||
|
|
||||||
|
NodeDef quantize_node;
|
||||||
|
quantize_node.set_op("QuantizeV2");
|
||||||
|
quantize_node.set_name(namespace_prefix + "/quantize");
|
||||||
|
SetNodeAttr("T", DT_QINT32, &quantize_node);
|
||||||
|
SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
|
||||||
|
AddNodeInput(fake_quant_node.input(0), &quantize_node);
|
||||||
|
AddNodeInput(fake_quant_min_node.name(), &quantize_node);
|
||||||
|
AddNodeInput(fake_quant_max_node.name(), &quantize_node);
|
||||||
|
new_nodes->push_back(quantize_node);
|
||||||
|
|
||||||
|
NodeDef requantize_node;
|
||||||
|
requantize_node.set_op("Requantize");
|
||||||
|
requantize_node.set_name(namespace_prefix + "/requantize");
|
||||||
|
SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
|
||||||
|
SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
|
||||||
|
AddNodeInput(quantize_node.name() + ":0", &requantize_node);
|
||||||
|
AddNodeInput(quantize_node.name() + ":1", &requantize_node);
|
||||||
|
AddNodeInput(quantize_node.name() + ":2", &requantize_node);
|
||||||
|
AddNodeInput(fake_quant_min_node.name(), &requantize_node);
|
||||||
|
AddNodeInput(fake_quant_max_node.name(), &requantize_node);
|
||||||
|
new_nodes->push_back(requantize_node);
|
||||||
|
|
||||||
|
// Convert the 8-bit result back into float for the final output.
|
||||||
|
NodeDef dequantize_node;
|
||||||
|
dequantize_node.set_op("Dequantize");
|
||||||
|
dequantize_node.set_name(fake_quant_node.name());
|
||||||
|
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
|
||||||
|
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
|
||||||
|
AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
|
||||||
|
AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
|
||||||
|
AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
|
||||||
|
new_nodes->push_back(dequantize_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, output_graph_def));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// We always generate Requantize ops driven by dynamic RequantizationRange
|
||||||
|
// calculations when we produce quantized ops like Conv2D or BiasAdd with
|
||||||
|
// 32-bit results. If there were FakeQuant ops already for those activation
|
||||||
|
// layers, then there will be a later Requantize op with constant min/max
|
||||||
|
// inputs, which is preferable for fast inference. This pass looks for those
|
||||||
|
// later Requantize ops, and replaces the dynamic version with them.
|
||||||
|
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, // clang-format off
|
||||||
|
{"Requantize",
|
||||||
|
{
|
||||||
|
{"QuantizeV2",
|
||||||
|
{
|
||||||
|
{"Dequantize",
|
||||||
|
{
|
||||||
|
{"Requantize",
|
||||||
|
{
|
||||||
|
{"*"},
|
||||||
|
{"*"},
|
||||||
|
{"*"},
|
||||||
|
{"RequantizationRange"},
|
||||||
|
{"RequantizationRange"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Requantize"},
|
||||||
|
{"Requantize"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"Const"},
|
||||||
|
{"Const"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"QuantizeV2"},
|
||||||
|
{"QuantizeV2"},
|
||||||
|
{"Const"},
|
||||||
|
{"Const"},
|
||||||
|
}
|
||||||
|
}, // clang-format on
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& fake_requantize_node = match.node;
|
||||||
|
const NodeDef& original_op_node =
|
||||||
|
match.inputs[0].inputs[0].inputs[0].inputs[0].node;
|
||||||
|
const NodeDef& fake_requantize_min_node = match.inputs[3].node;
|
||||||
|
const NodeDef& fake_requantize_max_node = match.inputs[4].node;
|
||||||
|
|
||||||
|
new_nodes->push_back(original_op_node);
|
||||||
|
new_nodes->push_back(fake_requantize_min_node);
|
||||||
|
new_nodes->push_back(fake_requantize_max_node);
|
||||||
|
|
||||||
|
NodeDef requantize_node;
|
||||||
|
requantize_node.CopyFrom(fake_requantize_node);
|
||||||
|
requantize_node.mutable_input()->Clear();
|
||||||
|
AddNodeInput(original_op_node.name() + ":0", &requantize_node);
|
||||||
|
AddNodeInput(original_op_node.name() + ":1", &requantize_node);
|
||||||
|
AddNodeInput(original_op_node.name() + ":2", &requantize_node);
|
||||||
|
AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
|
||||||
|
AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
|
||||||
|
new_nodes->push_back(requantize_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, output_graph_def));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
|
||||||
|
// linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd
|
||||||
|
// op that we want to apply the trained constant conversions to. This pass tries
|
||||||
|
// to move FakeQuant ops up the input chain, so they're as close as possible to
|
||||||
|
// the 32-bit conversion, and so can be easily merged into the automatic dynamic
|
||||||
|
// Requantizes.
|
||||||
|
Status HoistFakeQuants(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
GraphDef current_graph_def = input_graph_def;
|
||||||
|
const int max_depth = 3;
|
||||||
|
for (int depth = max_depth; depth > 0; --depth) {
|
||||||
|
OpTypePattern pattern = {"*"};
|
||||||
|
for (int i = 0; i < depth; ++i) {
|
||||||
|
pattern = {"*", {pattern}};
|
||||||
|
}
|
||||||
|
pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
|
||||||
|
GraphDef hoisted_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
current_graph_def, pattern,
|
||||||
|
[depth](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& fake_quant_node = match.node;
|
||||||
|
const NodeDef& fake_quant_min_node = match.inputs[1].node;
|
||||||
|
const NodeDef& fake_quant_max_node = match.inputs[2].node;
|
||||||
|
std::vector<NodeDef> linear_nodes;
|
||||||
|
NodeMatch current_match = match;
|
||||||
|
for (int i = 0; i <= depth; ++i) {
|
||||||
|
linear_nodes.push_back(current_match.inputs[0].node);
|
||||||
|
current_match = current_match.inputs[0];
|
||||||
|
}
|
||||||
|
NodeDef new_fake_quant_node;
|
||||||
|
new_fake_quant_node.CopyFrom(fake_quant_node);
|
||||||
|
new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
|
||||||
|
new_fake_quant_node.set_input(
|
||||||
|
0, linear_nodes[linear_nodes.size() - 2].input(0));
|
||||||
|
new_nodes->push_back(new_fake_quant_node);
|
||||||
|
|
||||||
|
new_nodes->push_back(fake_quant_min_node);
|
||||||
|
new_nodes->push_back(fake_quant_max_node);
|
||||||
|
|
||||||
|
linear_nodes[linear_nodes.size() - 2].set_input(
|
||||||
|
0, new_fake_quant_node.name());
|
||||||
|
linear_nodes.front().set_name(fake_quant_node.name());
|
||||||
|
for (const NodeDef& linear_node : linear_nodes) {
|
||||||
|
new_nodes->push_back(linear_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &hoisted_graph_def));
|
||||||
|
current_graph_def = hoisted_graph_def;
|
||||||
|
}
|
||||||
|
*output_graph_def = current_graph_def;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts any float ops that have eight-bit equivalents into their quantized
|
||||||
|
// forms, so that as much calculation as possible is done in the lower-precision
|
||||||
|
// format.
|
||||||
|
Status QuantizeNodes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
// Loop through all of the quantizable op types, and replace any occurrences
|
||||||
|
// with equivalent sub-graphs with quantized ops at their core. For example
|
||||||
|
// this one-input operation:
|
||||||
|
//
|
||||||
|
// Input(float)
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// Operation
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// (float)
|
||||||
|
//
|
||||||
|
// Will be turned into it's quantized equivalent:
|
||||||
|
//
|
||||||
|
// Input(float) ReshapeDims
|
||||||
|
// +------v v-------------+
|
||||||
|
// | Reshape
|
||||||
|
// | |
|
||||||
|
// | | ReductionDims
|
||||||
|
// | +-----+ |
|
||||||
|
// | | +---c---------+
|
||||||
|
// | v v v v-------+
|
||||||
|
// | Min Max
|
||||||
|
// | +----+ |
|
||||||
|
// v v v--------+
|
||||||
|
// Quantize
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// QuantizedOperation
|
||||||
|
// | | |
|
||||||
|
// v v v
|
||||||
|
// Dequantize
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
// (float)
|
||||||
|
//
|
||||||
|
// This keeps the inputs and outputs visible to the rest of the graph in
|
||||||
|
// float
|
||||||
|
// and converts them down to quantized buffers internally for the
|
||||||
|
// computation.
|
||||||
|
// The result will end up with a lot of redundant dequantize/quantize pairs
|
||||||
|
// between adjacent quantized ops, but a later pass removes these where it
|
||||||
|
// can.
|
||||||
|
const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
|
||||||
|
string op_pattern;
|
||||||
|
bool is_first = true;
|
||||||
|
std::map<string, QuantizedOpInfo> op_map;
|
||||||
|
for (const QuantizedOpInfo& op_info : op_list) {
|
||||||
|
strings::StrAppend(&op_pattern, (is_first ? "" : "|"), op_info.float_name);
|
||||||
|
op_map.insert({op_info.float_name, op_info});
|
||||||
|
is_first = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If input_min and input max have been passed in, then we convert all float
|
||||||
|
// Placeholder nodes into quantized versions, with the supplied values as
|
||||||
|
// their range.
|
||||||
|
GraphDef placeholder_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));
|
||||||
|
|
||||||
|
// If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
|
||||||
|
// operations like Relu or MaxPool, move them up so that they're as close as
|
||||||
|
// possible to ops with 32-bit outputs like BiasAdd or Conv2D.
|
||||||
|
GraphDef hoisted_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));
|
||||||
|
|
||||||
|
// Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
|
||||||
|
// activation layers, into Requantize ops with those ranges instead. This
|
||||||
|
// makes it easier to replace the dynamic range calculations that are used
|
||||||
|
// by default.
|
||||||
|
GraphDef converted_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
|
||||||
|
&converted_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));
|
||||||
|
|
||||||
|
// If fallback_min and fallback_max are set, then we'll use hardwired ranges
|
||||||
|
// for all the 32-bit to 8-bit requantizations.
|
||||||
|
float fallback_min;
|
||||||
|
float fallback_max;
|
||||||
|
bool has_fallback_range;
|
||||||
|
TF_RETURN_IF_ERROR(ExtractRangeFromParams(
|
||||||
|
context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
|
||||||
|
&has_fallback_range));
|
||||||
|
|
||||||
|
// Replace all occurrences of the current float op with its quantized
|
||||||
|
// equivalent.
|
||||||
|
GraphDef quantized_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
converted_graph_def, {op_pattern},
|
||||||
|
[&op_map, fallback_min, fallback_max, has_fallback_range](
|
||||||
|
const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& float_node = match.node;
|
||||||
|
const QuantizedOpInfo& op_info = op_map[float_node.op()];
|
||||||
|
|
||||||
|
string namespace_prefix = float_node.name() + "_eightbit";
|
||||||
|
|
||||||
|
// Quantize all of the inputs.
|
||||||
|
std::vector<string> quantized_input_names;
|
||||||
|
for (int i = 0; i < float_node.input_size(); ++i) {
|
||||||
|
// Skip any non-float inputs.
|
||||||
|
if (op_info.unquantized_inputs.count(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const string& input_name = float_node.input(i);
|
||||||
|
string unique_input_name =
|
||||||
|
namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);
|
||||||
|
|
||||||
|
// Add some common constants we need for reshaping inputs.
|
||||||
|
NodeDef reshape_dims;
|
||||||
|
reshape_dims.set_op("Const");
|
||||||
|
reshape_dims.set_name(unique_input_name + "/reshape_dims");
|
||||||
|
SetNodeAttr("dtype", DT_INT32, &reshape_dims);
|
||||||
|
Tensor reshape_dims_tensor(DT_INT32, {1});
|
||||||
|
reshape_dims_tensor.flat<int32>()(0) = -1;
|
||||||
|
SetNodeTensorAttr<int32>("value", reshape_dims_tensor, &reshape_dims);
|
||||||
|
new_nodes->push_back(reshape_dims);
|
||||||
|
|
||||||
|
NodeDef reduction_dims;
|
||||||
|
reduction_dims.set_op("Const");
|
||||||
|
reduction_dims.set_name(unique_input_name + "/reduction_dims");
|
||||||
|
SetNodeAttr("dtype", DT_INT32, &reduction_dims);
|
||||||
|
Tensor reduction_dims_tensor(DT_INT32, {1});
|
||||||
|
reduction_dims_tensor.flat<int32>()(0) = 0;
|
||||||
|
SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
|
||||||
|
&reduction_dims);
|
||||||
|
new_nodes->push_back(reduction_dims);
|
||||||
|
|
||||||
|
NodeDef reshape_node;
|
||||||
|
reshape_node.set_op("Reshape");
|
||||||
|
reshape_node.set_name(unique_input_name + "/reshape");
|
||||||
|
SetNodeAttr("T", DT_FLOAT, &reshape_node);
|
||||||
|
AddNodeInput(input_name, &reshape_node);
|
||||||
|
AddNodeInput(reshape_dims.name(), &reshape_node);
|
||||||
|
new_nodes->push_back(reshape_node);
|
||||||
|
|
||||||
|
NodeDef min_node;
|
||||||
|
min_node.set_op("Min");
|
||||||
|
min_node.set_name(unique_input_name + "/min");
|
||||||
|
SetNodeAttr("T", DT_FLOAT, &min_node);
|
||||||
|
SetNodeAttr("keep_dims", false, &min_node);
|
||||||
|
AddNodeInput(reshape_node.name(), &min_node);
|
||||||
|
AddNodeInput(reduction_dims.name(), &min_node);
|
||||||
|
new_nodes->push_back(min_node);
|
||||||
|
|
||||||
|
NodeDef max_node;
|
||||||
|
max_node.set_op("Max");
|
||||||
|
max_node.set_name(unique_input_name + "/max");
|
||||||
|
SetNodeAttr("T", DT_FLOAT, &max_node);
|
||||||
|
SetNodeAttr("keep_dims", false, &max_node);
|
||||||
|
AddNodeInput(reshape_node.name(), &max_node);
|
||||||
|
AddNodeInput(reduction_dims.name(), &max_node);
|
||||||
|
new_nodes->push_back(max_node);
|
||||||
|
|
||||||
|
NodeDef quantize_node;
|
||||||
|
quantize_node.set_op("QuantizeV2");
|
||||||
|
quantize_node.set_name(unique_input_name + "/quantize");
|
||||||
|
SetNodeAttr("T", DT_QUINT8, &quantize_node);
|
||||||
|
SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
|
||||||
|
AddNodeInput(input_name, &quantize_node);
|
||||||
|
AddNodeInput(min_node.name(), &quantize_node);
|
||||||
|
AddNodeInput(max_node.name(), &quantize_node);
|
||||||
|
new_nodes->push_back(quantize_node);
|
||||||
|
quantized_input_names.push_back(quantize_node.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the quantized version of the current op.
|
||||||
|
NodeDef quantized_main_node;
|
||||||
|
quantized_main_node.set_op("Quantized" + float_node.op());
|
||||||
|
quantized_main_node.set_name(float_node.name() + "/eightbit");
|
||||||
|
for (const string& attr_to_copy : op_info.attrs_to_copy) {
|
||||||
|
CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
|
||||||
|
&quantized_main_node);
|
||||||
|
}
|
||||||
|
for (const std::pair<string, DataType>& dtype_to_set :
|
||||||
|
op_info.dtypes_to_set) {
|
||||||
|
SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
|
||||||
|
&quantized_main_node);
|
||||||
|
}
|
||||||
|
int quantized_input_index = 0;
|
||||||
|
for (int i = 0; i < float_node.input_size(); ++i) {
|
||||||
|
if (op_info.unquantized_inputs.count(i)) {
|
||||||
|
AddNodeInput(float_node.input(i), &quantized_main_node);
|
||||||
|
} else {
|
||||||
|
const string& quantized_input_name =
|
||||||
|
quantized_input_names[quantized_input_index];
|
||||||
|
AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
|
||||||
|
++quantized_input_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
|
||||||
|
for (const string& quantized_input_name : quantized_input_names) {
|
||||||
|
AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
|
||||||
|
AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (const string& quantized_input_name : quantized_input_names) {
|
||||||
|
AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
|
||||||
|
}
|
||||||
|
for (const string& quantized_input_name : quantized_input_names) {
|
||||||
|
AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_nodes->push_back(quantized_main_node);
|
||||||
|
|
||||||
|
string eight_bit_node_name;
|
||||||
|
if (op_info.output_bit_depth == DT_QINT32) {
|
||||||
|
// Shrink the range of the output down from 32 bits to 8.
|
||||||
|
string requantize_min_input;
|
||||||
|
string requantize_max_input;
|
||||||
|
if (has_fallback_range) {
|
||||||
|
// Use constant values for the min/max range if they were given.
|
||||||
|
NodeDef fallback_min_node;
|
||||||
|
fallback_min_node.set_op("Const");
|
||||||
|
fallback_min_node.set_name(quantized_main_node.name() +
|
||||||
|
"/fallback_min");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
|
||||||
|
Tensor fallback_min_tensor(DT_FLOAT, {});
|
||||||
|
fallback_min_tensor.flat<float>()(0) = fallback_min;
|
||||||
|
SetNodeTensorAttr<float>("value", fallback_min_tensor,
|
||||||
|
&fallback_min_node);
|
||||||
|
new_nodes->push_back(fallback_min_node);
|
||||||
|
|
||||||
|
NodeDef fallback_max_node;
|
||||||
|
fallback_max_node.set_op("Const");
|
||||||
|
fallback_max_node.set_name(quantized_main_node.name() +
|
||||||
|
"/fallback_max");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
|
||||||
|
Tensor fallback_max_tensor(DT_FLOAT, {});
|
||||||
|
fallback_max_tensor.flat<float>()(0) = fallback_max;
|
||||||
|
SetNodeTensorAttr<float>("value", fallback_max_tensor,
|
||||||
|
&fallback_max_node);
|
||||||
|
new_nodes->push_back(fallback_max_node);
|
||||||
|
|
||||||
|
requantize_min_input = fallback_min_node.name();
|
||||||
|
requantize_max_input = fallback_max_node.name();
|
||||||
|
} else {
|
||||||
|
// Otherwise dynamically measure the range each time.
|
||||||
|
NodeDef requant_range_node;
|
||||||
|
requant_range_node.set_op("RequantizationRange");
|
||||||
|
requant_range_node.set_name(quantized_main_node.name() +
|
||||||
|
"/requant_range");
|
||||||
|
SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
|
||||||
|
AddNodeInput(quantized_main_node.name() + ":0",
|
||||||
|
&requant_range_node);
|
||||||
|
AddNodeInput(quantized_main_node.name() + ":1",
|
||||||
|
&requant_range_node);
|
||||||
|
AddNodeInput(quantized_main_node.name() + ":2",
|
||||||
|
&requant_range_node);
|
||||||
|
new_nodes->push_back(requant_range_node);
|
||||||
|
|
||||||
|
requantize_min_input = requant_range_node.name() + ":0";
|
||||||
|
requantize_max_input = requant_range_node.name() + ":1";
|
||||||
|
}
|
||||||
|
NodeDef requantize_node;
|
||||||
|
requantize_node.set_op("Requantize");
|
||||||
|
requantize_node.set_name(quantized_main_node.name() + "/requantize");
|
||||||
|
SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
|
||||||
|
SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
|
||||||
|
AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
|
||||||
|
AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
|
||||||
|
AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
|
||||||
|
AddNodeInput(requantize_min_input, &requantize_node);
|
||||||
|
AddNodeInput(requantize_max_input, &requantize_node);
|
||||||
|
new_nodes->push_back(requantize_node);
|
||||||
|
eight_bit_node_name = requantize_node.name();
|
||||||
|
} else {
|
||||||
|
eight_bit_node_name = quantized_main_node.name();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the 8-bit result back into float for the final output.
|
||||||
|
NodeDef dequantize_node;
|
||||||
|
dequantize_node.set_op("Dequantize");
|
||||||
|
dequantize_node.set_name(float_node.name());
|
||||||
|
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
|
||||||
|
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
|
||||||
|
AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
|
||||||
|
AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
|
||||||
|
AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
|
||||||
|
new_nodes->push_back(dequantize_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &quantized_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));
|
||||||
|
|
||||||
|
// If we've ended up with two Requantize ops in a row (for example if there
|
||||||
|
// was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
|
||||||
|
// using the trained range from the second op.
|
||||||
|
GraphDef merged_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
|
||||||
|
&merged_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));
|
||||||
|
|
||||||
|
// There can be duplicate quantize nodes if multiple ops pull from a single
|
||||||
|
// input, which makes it harder to remove redundant ones, so strip them out.
|
||||||
|
GraphDef deduped_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));
|
||||||
|
|
||||||
|
// Look for Dequantizes that immediately go into Quantizes, and remove them
|
||||||
|
// since the two together cancel each other out. This allows us to keep the
|
||||||
|
// data flow in eight bit where two adjacent ops are in eight bit, but still
|
||||||
|
// keep interoperability with float ops.
|
||||||
|
TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
|
||||||
|
output_graph_def));
|
||||||
|
TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
1321
tensorflow/tools/graph_transforms/quantize_nodes_test.cc
Normal file
1321
tensorflow/tools/graph_transforms/quantize_nodes_test.cc
Normal file
File diff suppressed because it is too large
Load Diff
139
tensorflow/tools/graph_transforms/quantize_weights.cc
Normal file
139
tensorflow/tools/graph_transforms/quantize_weights.cc
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Converts any large float constants into eight-bit equivalents, with a
|
||||||
|
// Dequantize op so that subsequent nodes can still access the results in a
|
||||||
|
// float form.
|
||||||
|
Status QuantizeWeights(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, {"Const"},
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& old_const_node = match.node;
|
||||||
|
if (!old_const_node.attr().count("dtype")) {
|
||||||
|
return errors::InvalidArgument("No 'dtype' attribute for Const node ",
|
||||||
|
old_const_node.name());
|
||||||
|
}
|
||||||
|
if (!old_const_node.attr().count("value")) {
|
||||||
|
return errors::InvalidArgument("No 'value' attribute for Const node ",
|
||||||
|
old_const_node.name());
|
||||||
|
}
|
||||||
|
const DataType old_dtype = old_const_node.attr().at("dtype").type();
|
||||||
|
Tensor old_tensor;
|
||||||
|
if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) {
|
||||||
|
return errors::InvalidArgument("Decoding Tensor failed for node",
|
||||||
|
old_const_node.name());
|
||||||
|
}
|
||||||
|
const size_t num_elements = old_tensor.NumElements();
|
||||||
|
// If this isn't a float constant, or it's too small, then reuse the
|
||||||
|
// same node with no changes.
|
||||||
|
if ((old_dtype != DT_FLOAT) || (num_elements < 16)) {
|
||||||
|
new_nodes->push_back(old_const_node);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
const float* old_values = old_tensor.flat<float>().data();
|
||||||
|
float min = std::numeric_limits<float>::max();
|
||||||
|
float max = std::numeric_limits<float>::min();
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
const float value = old_values[i];
|
||||||
|
min = std::min(min, value);
|
||||||
|
max = std::max(max, value);
|
||||||
|
}
|
||||||
|
// min_value == max_value is a tricky case. It can occur for general
|
||||||
|
// tensors, and of course for scalars. The quantized ops cannot deal
|
||||||
|
// with this case, so we set max_value to something else.
|
||||||
|
// It's a tricky question what is the numerically best solution to
|
||||||
|
// deal with this degeneracy.
|
||||||
|
// TODO(petewarden): Better use a tolerance than a hard comparison?
|
||||||
|
if (min == max) {
|
||||||
|
if (std::abs(min) < 0.000001f) {
|
||||||
|
max = min + 1.0f;
|
||||||
|
} else if (min > 0) {
|
||||||
|
max = 2.0f * min;
|
||||||
|
} else {
|
||||||
|
max = min / 2.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor quantized_tensor(DT_QUINT8, old_tensor.shape());
|
||||||
|
FloatTensorToQuantizedInPlace<quint8>(old_tensor, min, max,
|
||||||
|
&quantized_tensor);
|
||||||
|
|
||||||
|
NodeDef quantized_const_node;
|
||||||
|
quantized_const_node.set_op("Const");
|
||||||
|
quantized_const_node.set_name(old_const_node.name() +
|
||||||
|
"_quantized_const");
|
||||||
|
SetNodeAttr("dtype", DT_QUINT8, &quantized_const_node);
|
||||||
|
SetNodeTensorAttr<float>("value", quantized_tensor,
|
||||||
|
&quantized_const_node);
|
||||||
|
new_nodes->push_back(quantized_const_node);
|
||||||
|
|
||||||
|
NodeDef min_node;
|
||||||
|
min_node.set_op("Const");
|
||||||
|
min_node.set_name(old_const_node.name() + "_quantized_min");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &min_node);
|
||||||
|
Tensor min_tensor(DT_FLOAT, {});
|
||||||
|
min_tensor.scalar<float>()() = min;
|
||||||
|
SetNodeTensorAttr<float>("value", min_tensor, &min_node);
|
||||||
|
new_nodes->push_back(min_node);
|
||||||
|
|
||||||
|
NodeDef max_node;
|
||||||
|
max_node.set_op("Const");
|
||||||
|
max_node.set_name(old_const_node.name() + "_quantized_max");
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &max_node);
|
||||||
|
Tensor max_tensor(DT_FLOAT, {});
|
||||||
|
max_tensor.scalar<float>()() = max;
|
||||||
|
SetNodeTensorAttr<float>("value", max_tensor, &max_node);
|
||||||
|
new_nodes->push_back(max_node);
|
||||||
|
|
||||||
|
NodeDef dequantize_node;
|
||||||
|
dequantize_node.set_op("Dequantize");
|
||||||
|
dequantize_node.set_name(old_const_node.name());
|
||||||
|
SetNodeAttr("T", DT_QUINT8, &dequantize_node);
|
||||||
|
SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
|
||||||
|
AddNodeInput(quantized_const_node.name(), &dequantize_node);
|
||||||
|
AddNodeInput(min_node.name(), &dequantize_node);
|
||||||
|
AddNodeInput(max_node.name(), &dequantize_node);
|
||||||
|
new_nodes->push_back(dequantize_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, output_graph_def));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("quantize_weights", QuantizeWeights);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
103
tensorflow/tools/graph_transforms/quantize_weights_test.cc
Normal file
103
tensorflow/tools/graph_transforms/quantize_weights_test.cc
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status QuantizeWeights(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class QuantizeWeightsTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestQuantizeWeights() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
|
||||||
|
-5.0f, -3.0f, -6.0f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&weights_data,
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
|
||||||
|
3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
|
||||||
|
0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
|
||||||
|
0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op,
|
||||||
|
{1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef quantized_graph_def;
|
||||||
|
TF_ASSERT_OK(QuantizeWeights(original_graph_def, {{}, {"output"}},
|
||||||
|
&quantized_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> quantized_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(quantized_session->Create(quantized_graph_def));
|
||||||
|
std::vector<Tensor> quantized_outputs;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
quantized_session->Run({}, {"output"}, {}, &quantized_outputs));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], quantized_outputs[0],
|
||||||
|
0.5);
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(quantized_graph_def, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("input_op"));
|
||||||
|
const NodeDef* q_input_op = node_lookup.at("input_op");
|
||||||
|
EXPECT_EQ(DT_FLOAT, q_input_op->attr().at("dtype").type());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("weights_op"));
|
||||||
|
const NodeDef* q_weights_op = node_lookup.at("weights_op");
|
||||||
|
EXPECT_EQ("Dequantize", q_weights_op->op());
|
||||||
|
const string& weights_const_name =
|
||||||
|
NodeNameFromInput(q_weights_op->input(0));
|
||||||
|
EXPECT_EQ(1, node_lookup.count(weights_const_name));
|
||||||
|
const NodeDef* q_weights_const = node_lookup.at(weights_const_name);
|
||||||
|
EXPECT_EQ("Const", q_weights_const->op());
|
||||||
|
EXPECT_EQ(DT_QUINT8, q_weights_const->attr().at("dtype").type());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(QuantizeWeightsTest, TestQuantizeWeights) { TestQuantizeWeights(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
70
tensorflow/tools/graph_transforms/remove_attribute.cc
Normal file
70
tensorflow/tools/graph_transforms/remove_attribute.cc
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Deletes a given attribute from the specified nodes.
|
||||||
|
Status RemoveAttribute(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
if (!context.params.count("attribute_name") ||
|
||||||
|
(context.params.at("attribute_name").size() != 1)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"remove_nodes expects exactly one 'attribute_name' "
|
||||||
|
"argument, e.g. remove_attribute(op_name=Mul, attribute_name=foo)");
|
||||||
|
}
|
||||||
|
|
||||||
|
string op_name;
|
||||||
|
if (context.params.count("op_name")) {
|
||||||
|
if (context.params.at("op_name").size() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"remove_nodes expects a single op_name argument, but found ",
|
||||||
|
context.params.at("op_name").size());
|
||||||
|
}
|
||||||
|
op_name = context.params.at("op_name")[0];
|
||||||
|
} else {
|
||||||
|
op_name = "*";
|
||||||
|
}
|
||||||
|
|
||||||
|
const string attribute_name = context.params.at("attribute_name")[0];
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& node : input_graph_def.node()) {
|
||||||
|
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||||
|
new_node->CopyFrom(node);
|
||||||
|
if (((op_name == "*") || (op_name == node.op())) &&
|
||||||
|
(node.attr().count(attribute_name))) {
|
||||||
|
new_node->mutable_attr()->erase(attribute_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("remove_attribute", RemoveAttribute);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
123
tensorflow/tools/graph_transforms/remove_attribute_test.cc
Normal file
123
tensorflow/tools/graph_transforms/remove_attribute_test.cc
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RemoveAttribute(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class RemoveAttributeTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestRemoveAttribute() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node1 = graph_def.add_node();
|
||||||
|
mul_node1->set_name("mul_node1");
|
||||||
|
mul_node1->set_op("Mul");
|
||||||
|
mul_node1->add_input("add_node2");
|
||||||
|
mul_node1->add_input("add_node3");
|
||||||
|
SetNodeAttr<int32>("foo", 23, mul_node1);
|
||||||
|
SetNodeAttr<string>("bar", "something", mul_node1);
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
SetNodeAttr<int32>("foo", 46, add_node2);
|
||||||
|
SetNodeAttr<int32>("bob", 23, add_node2);
|
||||||
|
SetNodeAttr<string>("bar", "something else", add_node2);
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
GraphDef wildcard_result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"mul_node1"};
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op_name", {string("*")}}));
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"attribute_name", {string("foo")}}));
|
||||||
|
TF_ASSERT_OK(RemoveAttribute(graph_def, context, &wildcard_result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(wildcard_result, &node_lookup);
|
||||||
|
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
|
||||||
|
|
||||||
|
GraphDef targeted_result;
|
||||||
|
TransformFuncContext targeted_context;
|
||||||
|
targeted_context.input_names = {};
|
||||||
|
targeted_context.output_names = {"mul_node1"};
|
||||||
|
targeted_context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op_name", {string("Mul")}}));
|
||||||
|
targeted_context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"attribute_name", {string("foo")}}));
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
RemoveAttribute(graph_def, targeted_context, &targeted_result));
|
||||||
|
|
||||||
|
MapNamesToNodes(targeted_result, &node_lookup);
|
||||||
|
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(RemoveAttributeTest, TestRemoveAttribute) { TestRemoveAttribute(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
47
tensorflow/tools/graph_transforms/remove_device.cc
Normal file
47
tensorflow/tools/graph_transforms/remove_device.cc
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Clears the device field of all ops in the graph.
|
||||||
|
Status RemoveDevice(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& node : input_graph_def.node()) {
|
||||||
|
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||||
|
new_node->CopyFrom(node);
|
||||||
|
new_node->set_device("");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("remove_device", RemoveDevice);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
95
tensorflow/tools/graph_transforms/remove_device_test.cc
Normal file
95
tensorflow/tools/graph_transforms/remove_device_test.cc
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RemoveDevice(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class RemoveDeviceTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestRemoveDevice() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node1 = graph_def.add_node();
|
||||||
|
mul_node1->set_name("mul_node1");
|
||||||
|
mul_node1->set_op("Mul");
|
||||||
|
mul_node1->set_device("//cpu:0");
|
||||||
|
mul_node1->add_input("add_node2");
|
||||||
|
mul_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
add_node2->set_device("//gpu:1");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"mul_node1"};
|
||||||
|
TF_ASSERT_OK(RemoveDevice(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ("", node_lookup.at("mul_node1")->device());
|
||||||
|
EXPECT_EQ("", node_lookup.at("add_node2")->device());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(RemoveDeviceTest, TestRemoveDevice) { TestRemoveDevice(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
94
tensorflow/tools/graph_transforms/remove_nodes.cc
Normal file
94
tensorflow/tools/graph_transforms/remove_nodes.cc
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Deletes any specified types of nodes, unless they're necessary for the
|
||||||
|
// graph's inputs or outputs.
|
||||||
|
Status RemoveNodes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
if (!context.params.count("op")) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"remove_nodes expects at least one 'op'"
|
||||||
|
"argument, e.g. remove_nodes(op=Identity)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we don't get rid of any nodes used as graph inputs or outputs.
|
||||||
|
std::set<string> required_nodes;
|
||||||
|
for (const string& input : context.input_names) {
|
||||||
|
required_nodes.insert(NodeNameFromInput(input));
|
||||||
|
}
|
||||||
|
for (const string& output : context.output_names) {
|
||||||
|
required_nodes.insert(NodeNameFromInput(output));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> ops_to_remove = context.params.at("op");
|
||||||
|
GraphDef current_graph_def = input_graph_def;
|
||||||
|
for (const string& op : ops_to_remove) {
|
||||||
|
// Keep looking for nodes to remove until there are no more changes.
|
||||||
|
bool any_nodes_removed;
|
||||||
|
do {
|
||||||
|
any_nodes_removed = false;
|
||||||
|
std::map<string, string> inputs_to_rename;
|
||||||
|
GraphDef replaced_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
current_graph_def, {op, {{"*"}}},
|
||||||
|
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
|
||||||
|
const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& replace_node = match.node;
|
||||||
|
// If this node is needed in the inputs or outputs don't replace it.
|
||||||
|
if (required_nodes.count(replace_node.name())) {
|
||||||
|
LOG(INFO) << "Skipping replacement for " << replace_node.name();
|
||||||
|
CopyOriginalMatch(match, new_nodes);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
const NodeDef& input_node = match.inputs[0].node;
|
||||||
|
inputs_to_rename[replace_node.name()] = input_node.name();
|
||||||
|
inputs_to_rename["^" + replace_node.name()] =
|
||||||
|
"^" + input_node.name();
|
||||||
|
new_nodes->push_back(input_node);
|
||||||
|
any_nodes_removed = true;
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{true}, &replaced_graph_def));
|
||||||
|
// Make sure all references to removed nodes now point to their inputs.
|
||||||
|
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
||||||
|
¤t_graph_def);
|
||||||
|
} while (any_nodes_removed);
|
||||||
|
}
|
||||||
|
|
||||||
|
*output_graph_def = current_graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("remove_nodes", RemoveNodes);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
222
tensorflow/tools/graph_transforms/remove_nodes_test.cc
Normal file
222
tensorflow/tools/graph_transforms/remove_nodes_test.cc
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RemoveNodes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class RemoveNodesTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestRemoveNodes() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* add_node1 = graph_def.add_node();
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->add_input("add_node2");
|
||||||
|
add_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("identity_node1");
|
||||||
|
add_node2->add_input("identity_node2");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("identity_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* identity_node1 = graph_def.add_node();
|
||||||
|
identity_node1->set_name("identity_node1");
|
||||||
|
identity_node1->set_op("Identity");
|
||||||
|
identity_node1->add_input("const_node1");
|
||||||
|
|
||||||
|
NodeDef* identity_node2 = graph_def.add_node();
|
||||||
|
identity_node2->set_name("identity_node2");
|
||||||
|
identity_node2->set_op("Identity");
|
||||||
|
identity_node2->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* identity_node3 = graph_def.add_node();
|
||||||
|
identity_node3->set_name("identity_node3");
|
||||||
|
identity_node3->set_op("Identity");
|
||||||
|
identity_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"add_node1"};
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
|
||||||
|
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node1"));
|
||||||
|
EXPECT_EQ("add_node2", node_lookup.at("add_node1")->input(0));
|
||||||
|
EXPECT_EQ("add_node3", node_lookup.at("add_node1")->input(1));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node2"));
|
||||||
|
EXPECT_EQ("const_node1", node_lookup.at("add_node2")->input(0));
|
||||||
|
EXPECT_EQ("const_node2", node_lookup.at("add_node2")->input(1));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node3"));
|
||||||
|
EXPECT_EQ("const_node1", node_lookup.at("add_node3")->input(0));
|
||||||
|
EXPECT_EQ("const_node3", node_lookup.at("add_node3")->input(1));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node4"));
|
||||||
|
EXPECT_EQ("add_node2", node_lookup.at("add_node4")->input(0));
|
||||||
|
EXPECT_EQ("add_node3", node_lookup.at("add_node4")->input(1));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("identity_node1"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("identity_node2"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("identity_node3"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node1"));
|
||||||
|
EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node2"));
|
||||||
|
EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node3"));
|
||||||
|
EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRemoveOutputNodes() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("const_node1");
|
||||||
|
add_node->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* identity_node = graph_def.add_node();
|
||||||
|
identity_node->set_name("identity_node");
|
||||||
|
identity_node->set_op("Identity");
|
||||||
|
identity_node->add_input("add_node");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"identity_node"};
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
|
||||||
|
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
|
||||||
|
EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("identity_node"));
|
||||||
|
EXPECT_EQ("add_node", node_lookup.at("identity_node")->input(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRemoveChainedNodes() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* identity_node1 = graph_def.add_node();
|
||||||
|
identity_node1->set_name("identity_node1");
|
||||||
|
identity_node1->set_op("Identity");
|
||||||
|
identity_node1->add_input("const_node1");
|
||||||
|
|
||||||
|
NodeDef* identity_node2 = graph_def.add_node();
|
||||||
|
identity_node2->set_name("identity_node2");
|
||||||
|
identity_node2->set_op("Identity");
|
||||||
|
identity_node2->add_input("identity_node1");
|
||||||
|
|
||||||
|
NodeDef* identity_node3 = graph_def.add_node();
|
||||||
|
identity_node3->set_name("identity_node3");
|
||||||
|
identity_node3->set_op("Identity");
|
||||||
|
identity_node3->add_input("identity_node2");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("identity_node3");
|
||||||
|
add_node->add_input("const_node2");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"identity_node"};
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op", {string("Identity")}}));
|
||||||
|
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
|
||||||
|
EXPECT_EQ("const_node2", node_lookup.at("add_node")->input(1));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("identity_node1"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("identity_node2"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("identity_node3"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
|
||||||
|
|
||||||
|
TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
|
||||||
|
|
||||||
|
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
70
tensorflow/tools/graph_transforms/rename_attribute.cc
Normal file
70
tensorflow/tools/graph_transforms/rename_attribute.cc
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Renames an attribute on the nodes of a graph.
//
// Parameters read from context.params:
//   old_attribute_name  required, exactly one value: the attribute to rename.
//   new_attribute_name  required, exactly one value: the name it receives.
//   op_name             optional: only nodes whose op equals this value are
//                       changed. When absent, or "*", every node is eligible.
//                       Only the first value supplied is used.
//
// output_graph_def is cleared and then filled with a copy of every input node;
// on eligible nodes that carry old_attribute_name, the attribute's value is
// re-inserted under new_attribute_name and the old entry is erased.
//
// Returns InvalidArgument when the required parameters are missing or do not
// have exactly one value, or when op_name is present with no values.
Status RenameAttribute(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  if (!context.params.count("old_attribute_name") ||
      (context.params.at("old_attribute_name").size() != 1) ||
      !context.params.count("new_attribute_name") ||
      (context.params.at("new_attribute_name").size() != 1)) {
    return errors::InvalidArgument(
        "rename_attribute expects exactly one 'old_attribute_name' and one "
        "'new_attribute_name' argument, e.g. "
        "rename_attribute(old_attribute_name=foo, new_attribute_name=bar)");
  }

  string op_name = "*";
  if (context.params.count("op_name")) {
    const auto& op_names = context.params.at("op_name");
    // Indexing an empty list would be undefined behaviour, so reject it.
    if (op_names.empty()) {
      return errors::InvalidArgument(
          "rename_attribute expects a non-empty 'op_name' argument, e.g. "
          "rename_attribute(op_name=Mul, old_attribute_name=foo, "
          "new_attribute_name=bar)");
    }
    op_name = op_names[0];
  }

  const string old_attribute_name = context.params.at("old_attribute_name")[0];
  const string new_attribute_name = context.params.at("new_attribute_name")[0];
  output_graph_def->Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    NodeDef* new_node = output_graph_def->mutable_node()->Add();
    new_node->CopyFrom(node);
    if (((op_name == "*") || (op_name == node.op())) &&
        (node.attr().count(old_attribute_name))) {
      // Copy the value out before erasing the old entry.
      AttrValue attribute_value = node.attr().at(old_attribute_name);
      new_node->mutable_attr()->erase(old_attribute_name);
      // NOTE(review): insert() presumably leaves an already-present
      // new_attribute_name entry untouched -- confirm that is intended.
      new_node->mutable_attr()->insert({new_attribute_name, attribute_value});
    }
  }

  return Status::OK();
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("rename_attribute", RenameAttribute);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
131
tensorflow/tools/graph_transforms/rename_attribute_test.cc
Normal file
131
tensorflow/tools/graph_transforms/rename_attribute_test.cc
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RenameAttribute(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class RenameAttributeTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestRenameAttribute() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node1 = graph_def.add_node();
|
||||||
|
mul_node1->set_name("mul_node1");
|
||||||
|
mul_node1->set_op("Mul");
|
||||||
|
mul_node1->add_input("add_node2");
|
||||||
|
mul_node1->add_input("add_node3");
|
||||||
|
AddNodeAttr<int32>("foo", 23, mul_node1);
|
||||||
|
AddNodeAttr<string>("bar", "something", mul_node1);
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
AddNodeAttr<int32>("foo", 46, add_node2);
|
||||||
|
AddNodeAttr<int32>("bob", 23, add_node2);
|
||||||
|
AddNodeAttr<string>("bar", "something else", add_node2);
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
GraphDef wildcard_result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"mul_node1"};
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op_name", {string("*")}}));
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"old_attribute_name", {string("foo")}}));
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"new_attribute_name", {string("baz")}}));
|
||||||
|
TF_ASSERT_OK(RenameAttribute(graph_def, context, &wildcard_result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(wildcard_result, &node_lookup);
|
||||||
|
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("baz"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
|
||||||
|
|
||||||
|
GraphDef targeted_result;
|
||||||
|
TransformFuncContext targeted_context;
|
||||||
|
targeted_context.input_names = {};
|
||||||
|
targeted_context.output_names = {"mul_node1"};
|
||||||
|
targeted_context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"op_name", {string("Mul")}}));
|
||||||
|
targeted_context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"old_attribute_name", {string("foo")}}));
|
||||||
|
targeted_context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"new_attribute_name", {string("baz")}}));
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
RenameAttribute(graph_def, targeted_context, &targeted_result));
|
||||||
|
|
||||||
|
MapNamesToNodes(targeted_result, &node_lookup);
|
||||||
|
EXPECT_EQ(0, node_lookup.at("mul_node1")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("baz"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("mul_node1")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("foo"));
|
||||||
|
EXPECT_EQ(0, node_lookup.at("add_node2")->attr().count("baz"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bar"));
|
||||||
|
EXPECT_EQ(1, node_lookup.at("add_node2")->attr().count("bob"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(RenameAttributeTest, TestRenameAttribute) { TestRenameAttribute(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
60
tensorflow/tools/graph_transforms/rename_op.cc
Normal file
60
tensorflow/tools/graph_transforms/rename_op.cc
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Changes the op type of a specified op.
|
||||||
|
Status RenameOp(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
if (!context.params.count("old_op_name") ||
|
||||||
|
(context.params.at("old_op_name").size() != 1) ||
|
||||||
|
!context.params.count("new_op_name") ||
|
||||||
|
(context.params.at("new_op_name").size() != 1)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"remove_nodes expects exactly one 'old_op_name' and 'new_op_name' "
|
||||||
|
"argument, e.g. rename_op(old_op_name=Mul, new_op_name=Multiply)");
|
||||||
|
}
|
||||||
|
|
||||||
|
const string old_op_name = context.params.at("old_op_name")[0];
|
||||||
|
const string new_op_name = context.params.at("new_op_name")[0];
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& node : input_graph_def.node()) {
|
||||||
|
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||||
|
new_node->CopyFrom(node);
|
||||||
|
if (node.op() == old_op_name) {
|
||||||
|
new_node->set_op(new_op_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("rename_op", RenameOp);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
109
tensorflow/tools/graph_transforms/rename_op_test.cc
Normal file
109
tensorflow/tools/graph_transforms/rename_op_test.cc
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RenameOp(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class RenameOpTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestRenameOp() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node1 = graph_def.add_node();
|
||||||
|
mul_node1->set_name("mul_node1");
|
||||||
|
mul_node1->set_op("Mul");
|
||||||
|
mul_node1->add_input("add_node2");
|
||||||
|
mul_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"mul_node1"};
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"old_op_name", {string("Mul")}}));
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"new_op_name", {string("Multiply")}}));
|
||||||
|
TF_ASSERT_OK(RenameOp(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("mul_node1"));
|
||||||
|
EXPECT_EQ("Multiply", node_lookup.at("mul_node1")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node2"));
|
||||||
|
EXPECT_EQ("Add", node_lookup.at("add_node2")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node3"));
|
||||||
|
EXPECT_EQ("Add", node_lookup.at("add_node3")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node4"));
|
||||||
|
EXPECT_EQ("Add", node_lookup.at("add_node4")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node1"));
|
||||||
|
EXPECT_EQ("Const", node_lookup.at("const_node1")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node2"));
|
||||||
|
EXPECT_EQ("Const", node_lookup.at("const_node2")->op());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node3"));
|
||||||
|
EXPECT_EQ("Const", node_lookup.at("const_node3")->op());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(RenameOpTest, TestRenameOp) { TestRenameOp(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
123
tensorflow/tools/graph_transforms/round_weights.cc
Normal file
123
tensorflow/tools/graph_transforms/round_weights.cc
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Rounds any large float constants to the specified number of levels.
|
||||||
|
Status RoundWeights(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
string num_steps_string;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
|
||||||
|
int32 num_steps;
|
||||||
|
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Couldn't interpret the num_steps argument to round_weights as a "
|
||||||
|
"number:",
|
||||||
|
num_steps_string);
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
|
input_graph_def, {"Const"},
|
||||||
|
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
const NodeDef& old_const_node = match.node;
|
||||||
|
if (!old_const_node.attr().count("dtype")) {
|
||||||
|
return errors::InvalidArgument("No 'dtype' attribute for Const node ",
|
||||||
|
old_const_node.name());
|
||||||
|
}
|
||||||
|
if (!old_const_node.attr().count("value")) {
|
||||||
|
return errors::InvalidArgument("No 'value' attribute for Const node ",
|
||||||
|
old_const_node.name());
|
||||||
|
}
|
||||||
|
const DataType old_dtype = old_const_node.attr().at("dtype").type();
|
||||||
|
Tensor old_tensor;
|
||||||
|
if (!old_tensor.FromProto(old_const_node.attr().at("value").tensor())) {
|
||||||
|
return errors::InvalidArgument("Decoding Tensor failed for node",
|
||||||
|
old_const_node.name());
|
||||||
|
}
|
||||||
|
const size_t num_elements = old_tensor.NumElements();
|
||||||
|
// If this isn't a float constant, or it's too small, then reuse the
|
||||||
|
// same node with no changes. The size is important because small
|
||||||
|
// constants tend to be used for more accuracy-sensitive calculations,
|
||||||
|
// and the benefit of shrinking them is very marginal.
|
||||||
|
if ((old_dtype != DT_FLOAT) || (num_elements < 16)) {
|
||||||
|
new_nodes->push_back(old_const_node);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
const float* old_values = old_tensor.flat<float>().data();
|
||||||
|
float min = std::numeric_limits<float>::max();
|
||||||
|
float max = std::numeric_limits<float>::min();
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
const float value = old_values[i];
|
||||||
|
min = std::min(min, value);
|
||||||
|
max = std::max(max, value);
|
||||||
|
}
|
||||||
|
// min_value == max_value is a tricky case. It can occur for general
|
||||||
|
// tensors, and of course for scalars. The quantized ops cannot deal
|
||||||
|
// with this case, so we set max_value to something else.
|
||||||
|
// It's a tricky question what is the numerically best solution to
|
||||||
|
// deal with this degeneracy.
|
||||||
|
// TODO(petewarden): Better use a tolerance than a hard comparison?
|
||||||
|
if (min == max) {
|
||||||
|
if (std::abs(min) < 0.000001f) {
|
||||||
|
max = min + 1.0f;
|
||||||
|
} else if (min > 0) {
|
||||||
|
max = 2.0f * min;
|
||||||
|
} else {
|
||||||
|
min = 2.0f * max;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor rounded_tensor(DT_FLOAT, old_tensor.shape());
|
||||||
|
float* rounded_values = rounded_tensor.flat<float>().data();
|
||||||
|
const float bucket_width = (max - min) / num_steps;
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
const int32 bucket = std::floor((old_values[i] - min) / bucket_width);
|
||||||
|
rounded_values[i] = min + (bucket_width * (bucket + 0.5f));
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeDef rounded_const_node;
|
||||||
|
rounded_const_node.set_op("Const");
|
||||||
|
rounded_const_node.set_name(old_const_node.name());
|
||||||
|
SetNodeAttr("dtype", DT_FLOAT, &rounded_const_node);
|
||||||
|
SetNodeTensorAttr<float>("value", rounded_tensor, &rounded_const_node);
|
||||||
|
new_nodes->push_back(rounded_const_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, output_graph_def));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("round_weights", RoundWeights);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
96
tensorflow/tools/graph_transforms/round_weights_test.cc
Normal file
96
tensorflow/tools/graph_transforms/round_weights_test.cc
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RoundWeights(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class RoundWeightsTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestRoundWeights() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
|
||||||
|
-5.0f, -3.0f, -6.0f});
|
||||||
|
Output input_op =
|
||||||
|
Const(root.WithOpName("input_op"), Input::Initializer(input_data));
|
||||||
|
|
||||||
|
Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 10}));
|
||||||
|
test::FillValues<float>(
|
||||||
|
&weights_data,
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
|
||||||
|
3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
|
||||||
|
0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
|
||||||
|
0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
|
||||||
|
Output weights_op =
|
||||||
|
Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
|
||||||
|
|
||||||
|
Output conv_op = Conv2D(root.WithOpName("output"), input_op, weights_op,
|
||||||
|
{1, 1, 1, 1}, "VALID");
|
||||||
|
|
||||||
|
GraphDef original_graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(original_session->Create(original_graph_def));
|
||||||
|
std::vector<Tensor> original_outputs;
|
||||||
|
TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
|
||||||
|
|
||||||
|
GraphDef rounded_graph_def;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
RoundWeights(original_graph_def, {{}, {"output"}}, &rounded_graph_def));
|
||||||
|
|
||||||
|
std::unique_ptr<Session> rounded_session(NewSession(SessionOptions()));
|
||||||
|
TF_ASSERT_OK(rounded_session->Create(rounded_graph_def));
|
||||||
|
std::vector<Tensor> rounded_outputs;
|
||||||
|
TF_ASSERT_OK(rounded_session->Run({}, {"output"}, {}, &rounded_outputs));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(original_outputs[0], rounded_outputs[0], 0.5);
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(rounded_graph_def, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("input_op"));
|
||||||
|
const NodeDef* r_input_op = node_lookup.at("input_op");
|
||||||
|
EXPECT_EQ(DT_FLOAT, r_input_op->attr().at("dtype").type());
|
||||||
|
EXPECT_EQ(1, node_lookup.count("weights_op"));
|
||||||
|
const NodeDef* r_weights_op = node_lookup.at("weights_op");
|
||||||
|
EXPECT_EQ("Const", r_weights_op->op());
|
||||||
|
EXPECT_EQ(DT_FLOAT, r_weights_op->attr().at("dtype").type());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(RoundWeightsTest, TestRoundWeights) { TestRoundWeights(); }
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
43
tensorflow/tools/graph_transforms/sort_by_execution_order.cc
Normal file
43
tensorflow/tools/graph_transforms/sort_by_execution_order.cc
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// This is a thin wrapper with the standard TransformFunc interface to the
|
||||||
|
// underlying utility function. The only difference is that we don't use the
|
||||||
|
// input or output name arguments.
|
||||||
|
Status SortByExecutionOrderWithUnusedContext(
|
||||||
|
const GraphDef& input_graph_def, const TransformFuncContext& unused_context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
return SortByExecutionOrder(input_graph_def, output_graph_def);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("sort_by_execution_order",
|
||||||
|
SortByExecutionOrderWithUnusedContext);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,206 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
class SortByExecutionOrderTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void GetOrder(const GraphDef& graph_def, std::map<string, int>* order) {
|
||||||
|
for (int i = 0; i < graph_def.node_size(); ++i) {
|
||||||
|
const NodeDef& node = graph_def.node(i);
|
||||||
|
(*order)[node.name()] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSimpleAdd() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("a_node");
|
||||||
|
add_node->add_input("b_node");
|
||||||
|
|
||||||
|
NodeDef* b_node = graph_def.add_node();
|
||||||
|
b_node->set_name("b_node");
|
||||||
|
b_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* a_node = graph_def.add_node();
|
||||||
|
a_node->set_name("a_node");
|
||||||
|
a_node->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
|
||||||
|
|
||||||
|
std::map<string, int> order;
|
||||||
|
GetOrder(result, &order);
|
||||||
|
EXPECT_EQ(2, order["add_node"]);
|
||||||
|
EXPECT_GT(2, order["a_node"]);
|
||||||
|
EXPECT_GT(2, order["b_node"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSimpleLinear() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* negative_node = graph_def.add_node();
|
||||||
|
negative_node->set_name("negative_node");
|
||||||
|
negative_node->set_op("Negative");
|
||||||
|
negative_node->add_input("sqrt_node");
|
||||||
|
|
||||||
|
NodeDef* relu_node = graph_def.add_node();
|
||||||
|
relu_node->set_name("relu_node");
|
||||||
|
relu_node->set_op("Relu");
|
||||||
|
relu_node->add_input("const_node");
|
||||||
|
|
||||||
|
NodeDef* sqrt_node = graph_def.add_node();
|
||||||
|
sqrt_node->set_name("sqrt_node");
|
||||||
|
sqrt_node->set_op("Sqrt");
|
||||||
|
sqrt_node->add_input("relu_node");
|
||||||
|
|
||||||
|
NodeDef* const_node = graph_def.add_node();
|
||||||
|
const_node->set_name("const_node");
|
||||||
|
const_node->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
|
||||||
|
|
||||||
|
std::map<string, int> order;
|
||||||
|
GetOrder(result, &order);
|
||||||
|
EXPECT_EQ(3, order["negative_node"]);
|
||||||
|
EXPECT_EQ(2, order["sqrt_node"]);
|
||||||
|
EXPECT_EQ(1, order["relu_node"]);
|
||||||
|
EXPECT_EQ(0, order["const_node"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSimpleTree() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* add_node1 = graph_def.add_node();
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->add_input("add_node2");
|
||||||
|
add_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
add_node3->add_input("const_node4");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node4 = graph_def.add_node();
|
||||||
|
const_node4->set_name("const_node4");
|
||||||
|
const_node4->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
|
||||||
|
|
||||||
|
std::map<string, int> order;
|
||||||
|
GetOrder(result, &order);
|
||||||
|
EXPECT_EQ(6, order["add_node1"]);
|
||||||
|
EXPECT_GT(6, order["add_node2"]);
|
||||||
|
EXPECT_GT(6, order["add_node3"]);
|
||||||
|
EXPECT_GT(5, order["const_node1"]);
|
||||||
|
EXPECT_GT(5, order["const_node2"]);
|
||||||
|
EXPECT_GT(5, order["const_node3"]);
|
||||||
|
EXPECT_GT(5, order["const_node4"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCommonAncestor() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* add_node1 = graph_def.add_node();
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->add_input("add_node2");
|
||||||
|
add_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(SortByExecutionOrder(graph_def, &result));
|
||||||
|
|
||||||
|
std::map<string, int> order;
|
||||||
|
GetOrder(result, &order);
|
||||||
|
EXPECT_EQ(5, order["add_node1"]);
|
||||||
|
EXPECT_GT(5, order["add_node2"]);
|
||||||
|
EXPECT_GT(5, order["add_node3"]);
|
||||||
|
EXPECT_GT(4, order["const_node2"]);
|
||||||
|
EXPECT_GT(4, order["const_node3"]);
|
||||||
|
EXPECT_GT(3, order["const_node1"]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Each fixture method is exposed as its own test case.
TEST_F(SortByExecutionOrderTest, TestSimpleAdd) {
  TestSimpleAdd();
}

TEST_F(SortByExecutionOrderTest, TestSimpleLinear) {
  TestSimpleLinear();
}

TEST_F(SortByExecutionOrderTest, TestSimpleTree) {
  TestSimpleTree();
}

TEST_F(SortByExecutionOrderTest, TestCommonAncestor) {
  TestCommonAncestor();
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
215
tensorflow/tools/graph_transforms/strip_unused_nodes.cc
Normal file
215
tensorflow/tools/graph_transforms/strip_unused_nodes.cc
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/graph/subgraph.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Picks the data type to give the placeholder that replaces `node_name`.
//
// Precedence, lowest to highest: DT_FLOAT, then the single 'type' argument if
// present, then the 'type_for_name' entry paired with a 'name' entry equal to
// `node_name`. Returns InvalidArgument if 'type' has more or fewer than one
// value, if 'name' and 'type_for_name' are not both present with equal
// lengths, or if a type string cannot be parsed.
Status TypeForPlaceholder(const TransformFuncContext& context,
                          const string& node_name, DataType* result) {
  const auto& params = context.params;

  // Fallback used when no argument overrides it.
  *result = DT_FLOAT;

  // A single 'type' argument applies to every placeholder.
  const auto default_entry = params.find("type");
  if (default_entry != params.end()) {
    const auto& default_values = default_entry->second;
    if (default_values.size() != 1) {
      return errors::InvalidArgument(
          "You must pass no more than one default 'type' to "
          "strip_unused_nodes");
    }
    const string& type_string = default_values[0];
    if (!DataTypeFromString(type_string, result)) {
      return errors::InvalidArgument("Couldn't understand type argument '",
                                     type_string, "'");
    }
  }

  // Per-node overrides come as parallel 'name' / 'type_for_name' lists.
  const auto names_entry = params.find("name");
  const auto types_entry = params.find("type_for_name");
  const bool has_names = (names_entry != params.end());
  const bool has_types = (types_entry != params.end());
  if (!has_names && !has_types) {
    return Status::OK();
  }
  if (!has_names || !has_types ||
      (types_entry->second.size() != names_entry->second.size())) {
    return errors::InvalidArgument(
        "You must pass a 'type_for_name' arg for every 'name', e.g. "
        "strip_unused_nodes(name=foo, type_for_name=float, name=bar, "
        "type_for_name=quint8");
  }

  const auto& names = names_entry->second;
  const auto& types = types_entry->second;
  // No early exit: if the name is listed more than once, the last entry wins.
  for (size_t index = 0; index < names.size(); ++index) {
    if (names[index] != node_name) {
      continue;
    }
    const string& type_string = types[index];
    if (!DataTypeFromString(type_string, result)) {
      return errors::InvalidArgument("Couldn't understand type argument '",
                                     type_string, "'");
    }
  }

  return Status::OK();
}
|
||||||
|
|
||||||
|
// Takes a comma-separated string of numbers and parses them into a shape.
|
||||||
|
bool TensorShapeFromString(const string& shape_string, TensorShape* result) {
|
||||||
|
if (shape_string == "") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<int64> dims;
|
||||||
|
if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*result = TensorShape(dims);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Picks the shape to give the placeholder that replaces `node_name`.
//
// Precedence, lowest to highest: scalar, then the single 'shape' argument if
// present, then the 'shape_for_name' entry paired with a 'name' entry equal to
// `node_name`. Returns InvalidArgument if 'shape' has more or fewer than one
// value, if 'name' and 'shape_for_name' are not both present with equal
// lengths, or if a shape string cannot be parsed.
Status ShapeForPlaceholder(const TransformFuncContext& context,
                           const string& node_name, TensorShape* result) {
  // If we don't find anything else, return scalar.
  *result = {};

  // Check to see if we have been given a default for all placeholders. The
  // key tested here must be the one read below: testing 'type' instead would
  // ignore a lone 'shape' argument and read a missing key for a lone 'type'.
  if (context.params.count("shape")) {
    if (context.params.at("shape").size() != 1) {
      return errors::InvalidArgument(
          "You must pass no more than one default 'shape' to "
          "strip_unused_nodes");
    }
    const string& shape_string = context.params.at("shape")[0];
    if (!TensorShapeFromString(shape_string, result)) {
      return errors::InvalidArgument("Couldn't understand shape argument '",
                                     shape_string, "'");
    }
  }

  // See if there's a particular shape specified for this placeholder. The
  // length check has to cover 'shape_for_name' because that is the list
  // indexed in the loop below.
  if (context.params.count("name") || context.params.count("shape_for_name")) {
    if (!context.params.count("name") ||
        !context.params.count("shape_for_name") ||
        (context.params.at("shape_for_name").size() !=
         context.params.at("name").size())) {
      return errors::InvalidArgument(
          "You must pass a 'shape_for_name' arg for every 'name', e.g. "
          "strip_unused_nodes(name=foo, shape_for_name=\"2,2,1\", name=bar, "
          "shape_for_name=\"1\"");
    }
    const int name_count = context.params.at("name").size();
    for (int i = 0; i < name_count; ++i) {
      if (context.params.at("name")[i] == node_name) {
        const string& shape_string = context.params.at("shape_for_name")[i];
        if (!TensorShapeFromString(shape_string, result)) {
          return errors::InvalidArgument("Couldn't understand shape argument '",
                                         shape_string, "'");
        }
      }
    }
  }

  return Status::OK();
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Delete any nodes that don't contribute to the inference result.
|
||||||
|
Status StripUnusedNodes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
std::set<string> required_nodes;
|
||||||
|
std::set<string> input_nodes;
|
||||||
|
for (const string& input : context.input_names) {
|
||||||
|
required_nodes.insert(NodeNameFromInput(input));
|
||||||
|
input_nodes.insert(NodeNameFromInput(input));
|
||||||
|
}
|
||||||
|
for (const string& output : context.output_names) {
|
||||||
|
required_nodes.insert(output);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(input_graph_def, &node_lookup);
|
||||||
|
|
||||||
|
std::vector<string> current_inputs;
|
||||||
|
for (const string& output_name : context.output_names) {
|
||||||
|
current_inputs.push_back(NodeNameFromInput(output_name));
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!current_inputs.empty()) {
|
||||||
|
std::set<string> next_inputs;
|
||||||
|
for (const string& current_input : current_inputs) {
|
||||||
|
required_nodes.insert(current_input);
|
||||||
|
if (input_nodes.count(current_input)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!node_lookup.count(current_input)) {
|
||||||
|
return errors::InvalidArgument("Input node ", current_input,
|
||||||
|
" not found in graph");
|
||||||
|
}
|
||||||
|
const NodeDef* current_node = node_lookup[current_input];
|
||||||
|
for (const string& input_name : current_node->input()) {
|
||||||
|
string input_node_name = NodeNameFromInput(input_name);
|
||||||
|
if (!required_nodes.count(input_node_name)) {
|
||||||
|
next_inputs.insert(input_node_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current_inputs =
|
||||||
|
std::vector<string>(next_inputs.begin(), next_inputs.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef filtered_graph_def;
|
||||||
|
FilterGraphDef(input_graph_def,
|
||||||
|
[&](const NodeDef& node) {
|
||||||
|
return required_nodes.count(node.name()) > 0;
|
||||||
|
},
|
||||||
|
&filtered_graph_def);
|
||||||
|
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& node : filtered_graph_def.node()) {
|
||||||
|
if (input_nodes.count(node.name())) {
|
||||||
|
NodeDef placeholder_node;
|
||||||
|
if (node.op() == "Placeholder") {
|
||||||
|
placeholder_node.CopyFrom(node);
|
||||||
|
} else {
|
||||||
|
placeholder_node.set_op("Placeholder");
|
||||||
|
placeholder_node.set_name(node.name());
|
||||||
|
DataType type;
|
||||||
|
TF_RETURN_IF_ERROR(TypeForPlaceholder(context, node.name(), &type));
|
||||||
|
TensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(ShapeForPlaceholder(context, node.name(), &shape));
|
||||||
|
SetNodeAttr("dtype", type, &placeholder_node);
|
||||||
|
SetNodeAttr("shape", shape, &placeholder_node);
|
||||||
|
}
|
||||||
|
*(output_graph_def->mutable_node()->Add()) = placeholder_node;
|
||||||
|
} else {
|
||||||
|
*(output_graph_def->mutable_node()->Add()) = node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("strip_unused_nodes", StripUnusedNodes);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
286
tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc
Normal file
286
tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status StripUnusedNodes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class StripUnusedNodesTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestSimpleAdd() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("a_node");
|
||||||
|
add_node->add_input("b_node");
|
||||||
|
|
||||||
|
NodeDef* a_node = graph_def.add_node();
|
||||||
|
a_node->set_name("a_node");
|
||||||
|
a_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* b_node = graph_def.add_node();
|
||||||
|
b_node->set_name("b_node");
|
||||||
|
b_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* c_node = graph_def.add_node();
|
||||||
|
c_node->set_name("c_node");
|
||||||
|
c_node->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(StripUnusedNodes(graph_def, {{}, {"add_node"}}, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("a_node"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("b_node"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("c_node"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCommonAncestor() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* add_node1 = graph_def.add_node();
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->add_input("add_node2");
|
||||||
|
add_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* dangling_input = graph_def.add_node();
|
||||||
|
dangling_input->set_name("dangling_input");
|
||||||
|
dangling_input->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(StripUnusedNodes(
|
||||||
|
graph_def, {{"dangling_input"}, {"add_node1"}}, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node1"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node2"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node3"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("add_node4"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node1"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node2"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("const_node3"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("const_node4"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("dangling_input"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSimplePlaceholder() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("mul_node");
|
||||||
|
add_node->add_input("a_node");
|
||||||
|
|
||||||
|
NodeDef* mul_node = graph_def.add_node();
|
||||||
|
mul_node->set_name("mul_node");
|
||||||
|
mul_node->set_op("Mul");
|
||||||
|
mul_node->add_input("b_node");
|
||||||
|
mul_node->add_input("c_node");
|
||||||
|
|
||||||
|
NodeDef* a_node = graph_def.add_node();
|
||||||
|
a_node->set_name("a_node");
|
||||||
|
a_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* b_node = graph_def.add_node();
|
||||||
|
b_node->set_name("b_node");
|
||||||
|
b_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* c_node = graph_def.add_node();
|
||||||
|
c_node->set_name("c_node");
|
||||||
|
c_node->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
StripUnusedNodes(graph_def, {{"mul_node"}, {"add_node"}}, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("mul_node"));
|
||||||
|
EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
|
||||||
|
EXPECT_EQ(DT_FLOAT, node_lookup["mul_node"]->attr().at("dtype").type());
|
||||||
|
EXPECT_EQ(TensorShape({}),
|
||||||
|
TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("a_node"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("b_node"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("c_node"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestPlaceholderDefaultArgs() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("mul_node");
|
||||||
|
add_node->add_input("a_node");
|
||||||
|
|
||||||
|
NodeDef* mul_node = graph_def.add_node();
|
||||||
|
mul_node->set_name("mul_node");
|
||||||
|
mul_node->set_op("Mul");
|
||||||
|
mul_node->add_input("b_node");
|
||||||
|
mul_node->add_input("c_node");
|
||||||
|
|
||||||
|
NodeDef* a_node = graph_def.add_node();
|
||||||
|
a_node->set_name("a_node");
|
||||||
|
a_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* b_node = graph_def.add_node();
|
||||||
|
b_node->set_name("b_node");
|
||||||
|
b_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* c_node = graph_def.add_node();
|
||||||
|
c_node->set_name("c_node");
|
||||||
|
c_node->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(StripUnusedNodes(graph_def,
|
||||||
|
{{"mul_node"},
|
||||||
|
{"add_node"},
|
||||||
|
{{"type", {"int32"}}, {"shape", {"1,2,3"}}}},
|
||||||
|
&result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("mul_node"));
|
||||||
|
EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
|
||||||
|
EXPECT_EQ(DT_INT32, node_lookup["mul_node"]->attr().at("dtype").type());
|
||||||
|
EXPECT_EQ(TensorShape({1, 2, 3}),
|
||||||
|
TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("a_node"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("b_node"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("c_node"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestPlaceholderNamedArgs() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("mul_node");
|
||||||
|
add_node->add_input("a_node");
|
||||||
|
|
||||||
|
NodeDef* mul_node = graph_def.add_node();
|
||||||
|
mul_node->set_name("mul_node");
|
||||||
|
mul_node->set_op("Mul");
|
||||||
|
mul_node->add_input("b_node");
|
||||||
|
mul_node->add_input("c_node");
|
||||||
|
|
||||||
|
NodeDef* a_node = graph_def.add_node();
|
||||||
|
a_node->set_name("a_node");
|
||||||
|
a_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* b_node = graph_def.add_node();
|
||||||
|
b_node->set_name("b_node");
|
||||||
|
b_node->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* c_node = graph_def.add_node();
|
||||||
|
c_node->set_name("c_node");
|
||||||
|
c_node->set_op("Const");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(StripUnusedNodes(graph_def,
|
||||||
|
{{"mul_node", "a_node"},
|
||||||
|
{"add_node"},
|
||||||
|
{{"name", {"a_node", "mul_node"}},
|
||||||
|
{"type_for_name", {"int64", "quint8"}},
|
||||||
|
{"shape_for_name", {"1,2", "1, 2, 3"}}}},
|
||||||
|
&result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("mul_node"));
|
||||||
|
EXPECT_EQ("Placeholder", node_lookup["mul_node"]->op());
|
||||||
|
EXPECT_EQ(DT_QUINT8, node_lookup["mul_node"]->attr().at("dtype").type());
|
||||||
|
EXPECT_EQ(TensorShape({1, 2, 3}),
|
||||||
|
TensorShape(node_lookup["mul_node"]->attr().at("shape").shape()));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("a_node"));
|
||||||
|
EXPECT_EQ("Placeholder", node_lookup["a_node"]->op());
|
||||||
|
EXPECT_EQ(DT_INT64, node_lookup["a_node"]->attr().at("dtype").type());
|
||||||
|
EXPECT_EQ(TensorShape({1, 2}),
|
||||||
|
TensorShape(node_lookup["a_node"]->attr().at("shape").shape()));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("b_node"));
|
||||||
|
EXPECT_EQ(0, node_lookup.count("c_node"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Each fixture method is exposed as its own test case.
TEST_F(StripUnusedNodesTest, TestSimpleAdd) {
  TestSimpleAdd();
}

TEST_F(StripUnusedNodesTest, TestCommonAncestor) {
  TestCommonAncestor();
}

TEST_F(StripUnusedNodesTest, TestSimplePlaceholder) {
  TestSimplePlaceholder();
}

TEST_F(StripUnusedNodesTest, TestPlaceholderDefaultArgs) {
  TestPlaceholderDefaultArgs();
}

TEST_F(StripUnusedNodesTest, TestPlaceholderNamedArgs) {
  TestPlaceholderNamedArgs();
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
205
tensorflow/tools/graph_transforms/summarize_graph_main.cc
Normal file
205
tensorflow/tools/graph_transforms/summarize_graph_main.cc
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This program prints out a summary of a GraphDef file's contents, listing
|
||||||
|
// things that are useful for debugging and reusing the model it contains. For
|
||||||
|
// example it looks at the graph structure and op types to figure out likely
|
||||||
|
// input and output nodes, and shows which ops are used by the graph. To use it,
|
||||||
|
// run something like this:
|
||||||
|
//
|
||||||
|
// bazel build tensorflow/tools/graph_transforms:summarize_graph
|
||||||
|
// bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
|
||||||
|
// --in_graph=my_graph.pb
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Prints a summary of `graph` to stdout:
//   * Placeholder nodes, listed as possible inputs with their dtype and shape;
//   * nodes that no other node consumes, listed as possible outputs;
//   * counts of Const, Variable and Identity nodes and of control edges;
//   * how many nodes are assigned to each device;
//   * how many nodes use each op type, most frequent first.
// Returns an Internal error, after listing the offenders, if FindInvalidInputs
// reports any; the op-type listing is not printed in that case.
Status SummarizeGraph(const GraphDef& graph) {
  std::vector<const NodeDef*> placeholders;
  for (const NodeDef& node : graph.node()) {
    if (node.op() == "Placeholder") {
      placeholders.push_back(&node);
    }
  }

  if (placeholders.empty()) {
    std::cout << "No inputs spotted." << std::endl;
  } else {
    std::cout << "Found " << placeholders.size() << " possible inputs: ";
    for (const NodeDef* node : placeholders) {
      // A missing 'shape' attr is reported as a default-constructed shape.
      TensorShape shape;
      if (node->attr().count("shape")) {
        TensorShapeProto shape_proto = node->attr().at("shape").shape();
        shape = TensorShape(shape_proto);
      }
      std::cout << "(name=" << node->name();
      // Guard the lookup the same way as 'shape' so that a Placeholder with
      // no 'dtype' attr is reported instead of failing inside at().
      if (node->attr().count("dtype")) {
        const DataType dtype = node->attr().at("dtype").type();
        std::cout << ", type=" << DataTypeString(dtype) << "(" << dtype << ")";
      } else {
        std::cout << ", type=unknown";
      }
      std::cout << ", shape=" << shape.DebugString() << ") ";
    }
    std::cout << std::endl;
  }

  // A node whose name has no entry in the output map is treated as a possible
  // graph output.
  std::map<string, std::vector<const NodeDef*>> output_map;
  MapNodesToOutputs(graph, &output_map);
  std::vector<const NodeDef*> outputs;
  for (const NodeDef& node : graph.node()) {
    if (output_map.count(node.name()) == 0) {
      outputs.push_back(&node);
    }
  }

  if (outputs.empty()) {
    std::cout << "No outputs spotted." << std::endl;
  } else {
    std::cout << "Found " << outputs.size() << " possible outputs: ";
    for (const NodeDef* node : outputs) {
      std::cout << "(name=" << node->name();
      std::cout << ", op=" << node->op() << ") ";
    }
    std::cout << std::endl;
  }

  int const_count = 0;
  int variable_count = 0;
  int identity_count = 0;
  int control_edge_count = 0;
  std::map<string, int> device_counts;
  for (const NodeDef& node : graph.node()) {
    if (node.op() == "Const") {
      ++const_count;
    } else if (node.op() == "Variable") {
      ++variable_count;
    } else if (node.op() == "Identity") {
      ++identity_count;
    }
    for (const string& input : node.input()) {
      // Inputs starting with '^' are counted as control edges.
      if (!input.empty() && input[0] == '^') {
        ++control_edge_count;
      }
    }
    if (!node.device().empty()) {
      ++device_counts[node.device()];
    }
  }

  std::cout << "Found " << const_count << " consts, " << variable_count
            << " variables, " << identity_count << " identities, and "
            << control_edge_count << " control_edges" << std::endl;
  // One line per device; without the terminator the entries would run into
  // each other and into the output that follows.
  for (const auto& device_info : device_counts) {
    std::cout << device_info.second << " nodes assigned to device '"
              << device_info.first << "'" << std::endl;
  }

  std::vector<std::pair<string, string>> invalid_inputs;
  FindInvalidInputs(graph, &invalid_inputs);
  if (!invalid_inputs.empty()) {
    for (const std::pair<string, string>& invalid_input : invalid_inputs) {
      std::cout << "Invalid input " << invalid_input.second << " for node "
                << invalid_input.first << std::endl;
    }
    return errors::Internal(
        "Invalid graph with inputs referring to nonexistent nodes");
  }

  std::map<string, int> op_counts;
  for (const NodeDef& node : graph.node()) {
    ++op_counts[node.op()];
  }
  std::vector<std::pair<string, int>> op_counts_vec(op_counts.begin(),
                                                    op_counts.end());
  // The vector starts in the map's key order, so a stable sort keeps ops with
  // equal counts alphabetical. The comparator takes references to avoid
  // copying a string on every comparison.
  std::stable_sort(op_counts_vec.begin(), op_counts_vec.end(),
                   [](const std::pair<string, int>& a,
                      const std::pair<string, int>& b) {
                     return a.second > b.second;
                   });
  std::cout << "Op types used: ";
  bool is_first = true;
  for (const std::pair<string, int>& op_count : op_counts_vec) {
    if (!is_first) {
      std::cout << ", ";
    } else {
      is_first = false;
    }
    std::cout << op_count.second << " " << op_count.first;
  }
  std::cout << std::endl;

  return Status::OK();
}
|
||||||
|
|
||||||
|
int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
|
||||||
|
string in_graph = "";
|
||||||
|
string out_graph = "";
|
||||||
|
string inputs_string = "";
|
||||||
|
string outputs_string = "";
|
||||||
|
string transforms_string = "";
|
||||||
|
std::vector<Flag> flag_list = {
|
||||||
|
Flag("in_graph", &in_graph, "input graph file name"),
|
||||||
|
};
|
||||||
|
string usage = Flags::Usage(argv[0], flag_list);
|
||||||
|
|
||||||
|
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
|
||||||
|
// We need to call this to set up global state for TensorFlow.
|
||||||
|
port::InitMain(argv[0], &argc, &argv);
|
||||||
|
|
||||||
|
if (!parse_result) {
|
||||||
|
LOG(ERROR) << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (argc > 1) {
|
||||||
|
LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (in_graph.empty()) {
|
||||||
|
LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
|
||||||
|
if (!load_status.ok()) {
|
||||||
|
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
|
||||||
|
<< load_status.error_message();
|
||||||
|
LOG(ERROR) << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status summarize_result = SummarizeGraph(graph_def);
|
||||||
|
if (!summarize_result.ok()) {
|
||||||
|
LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
return tensorflow::graph_transforms::ParseFlagsAndSummarizeGraph(argc, argv);
|
||||||
|
}
|
280
tensorflow/tools/graph_transforms/transform_graph.cc
Normal file
280
tensorflow/tools/graph_transforms/transform_graph.cc
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/strings/scanner.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
using tensorflow::strings::Scanner;
|
||||||
|
|
||||||
|
// Parses a transform specification string into an ordered list of
// (transform name, parameters) pairs.
//
// The input is a whitespace-separated list of transform names. Each name may
// be followed by a parenthesized list of name=value arguments separated by
// commas, for example:
//   foo bar(a=1, b="1,2,3", a=3)
// Values are either bare (letters, digits, '-', '.', '/' and '_') or enclosed
// in double quotes. A parameter name that appears several times accumulates
// all of its values in the order given.
//
// params_list is cleared before parsing. Returns InvalidArgument if a
// transform name, parameter name, '=' or parameter value is missing where one
// is expected, or if the string ends inside an argument list.
Status ParseTransformParameters(const string& transforms_string,
                                TransformParameters* params_list) {
  params_list->clear();
  enum {
    TRANSFORM_NAME,
    TRANSFORM_PARAM_NAME,
    TRANSFORM_PARAM_VALUE,
  } state = TRANSFORM_NAME;
  StringPiece remaining(transforms_string);
  StringPiece match;
  StringPiece transform_name;
  StringPiece parameter_name;
  StringPiece parameter_value;
  TransformFuncParameters func_parameters;
  while (!remaining.empty()) {
    if (state == TRANSFORM_NAME) {
      // Reset the list of parameters.
      func_parameters.clear();
      // Eat up any leading spaces.
      Scanner(remaining).Any(Scanner::SPACE).GetResult(&remaining, &match);
      if (remaining.empty()) {
        // Only trailing whitespace was left, so there is no further transform.
        break;
      }
      // See if we have a valid transform name. Any() also succeeds on a
      // zero-length match, so an empty name has to be rejected explicitly;
      // otherwise nothing is consumed and this loop would never terminate.
      const bool found_transform_name =
          Scanner(remaining)
              .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
              .GetResult(&remaining, &transform_name);
      if (!found_transform_name || transform_name.empty()) {
        return errors::InvalidArgument("Looking for transform name, but found ",
                                       remaining.ToString().c_str());
      }
      if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
        state = TRANSFORM_PARAM_NAME;
      } else {
        // Add a transform with no parameters.
        params_list->push_back({transform_name.ToString(), func_parameters});
        transform_name = "";
        state = TRANSFORM_NAME;
      }
    } else if (state == TRANSFORM_PARAM_NAME) {
      if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
        // End of the argument list: record the transform and its parameters.
        params_list->push_back({transform_name.ToString(), func_parameters});
        transform_name = "";
        state = TRANSFORM_NAME;
      } else {
        // Eat up any leading spaces or commas.
        Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match);
        Scanner(remaining).Any(Scanner::SPACE).GetResult(&remaining, &match);
        // See if we have a valid (non-empty) parameter name.
        const bool found_parameter_name =
            Scanner(remaining)
                .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
                .GetResult(&remaining, &parameter_name);
        if (!found_parameter_name || parameter_name.empty()) {
          return errors::InvalidArgument(
              "Looking for parameter name, but found ",
              remaining.ToString().c_str());
        }
        if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
          state = TRANSFORM_PARAM_VALUE;
        } else {
          return errors::InvalidArgument("Looking for =, but found ",
                                         remaining.ToString().c_str());
        }
      }
    } else if (state == TRANSFORM_PARAM_VALUE) {
      bool found_parameter_value;
      // Deal with quoted values.
      if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) {
        found_parameter_value =
            Scanner(remaining).ScanEscapedUntil('"').GetResult(
                &remaining, &parameter_value);
        if (found_parameter_value) {
          // Consume the closing quote.
          Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match);
        }
      } else {
        // See if we have a valid unquoted parameter value.
        found_parameter_value =
            Scanner(remaining)
                .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
                .GetResult(&remaining, &parameter_value);
      }
      if (!found_parameter_value) {
        return errors::InvalidArgument(
            "Looking for parameter value, but found ",
            remaining.ToString().c_str());
      }
      func_parameters[parameter_name.ToString()].push_back(
          parameter_value.ToString());
      // Eat up any trailing quotes.
      Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
      Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
      state = TRANSFORM_PARAM_NAME;
    }
  }
  // The parser is only back in TRANSFORM_NAME once every "(" has been closed.
  // Ending in any other state means the last transform was never recorded, so
  // report it instead of silently dropping it.
  if (state != TRANSFORM_NAME) {
    return errors::InvalidArgument(
        "Unterminated parameter list for transform '",
        transform_name.ToString(), "', expected )");
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
||||||
|
string in_graph = "";
|
||||||
|
string out_graph = "";
|
||||||
|
string inputs_string = "";
|
||||||
|
string outputs_string = "";
|
||||||
|
string transforms_string = "";
|
||||||
|
std::vector<Flag> flag_list = {
|
||||||
|
Flag("in_graph", &in_graph, "input graph file name"),
|
||||||
|
Flag("out_graph", &out_graph, "output graph file name"),
|
||||||
|
Flag("inputs", &inputs_string, "inputs"),
|
||||||
|
Flag("outputs", &outputs_string, "outputs"),
|
||||||
|
Flag("transforms", &transforms_string, "list of transforms"),
|
||||||
|
};
|
||||||
|
string usage = Flags::Usage(argv[0], flag_list);
|
||||||
|
usage += "\nTransforms are:\n";
|
||||||
|
TransformRegistry* transform_registry = GetTransformRegistry();
|
||||||
|
for (const auto& pair : *transform_registry) {
|
||||||
|
usage += pair.first + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
|
||||||
|
// We need to call this to set up global state for TensorFlow.
|
||||||
|
if (init_main) {
|
||||||
|
port::InitMain(argv[0], &argc, &argv);
|
||||||
|
}
|
||||||
|
if (!parse_result) {
|
||||||
|
LOG(ERROR) << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (argc > 1) {
|
||||||
|
LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (in_graph.empty()) {
|
||||||
|
LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (out_graph.empty()) {
|
||||||
|
LOG(ERROR) << "out_graph graph can't be empty.\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (transforms_string.empty()) {
|
||||||
|
LOG(ERROR) << "You must specify at least one transform.\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> inputs = str_util::Split(inputs_string, ',');
|
||||||
|
std::vector<string> outputs = str_util::Split(outputs_string, ',');
|
||||||
|
TransformParameters transform_params;
|
||||||
|
Status parse_status =
|
||||||
|
ParseTransformParameters(transforms_string, &transform_params);
|
||||||
|
if (!parse_status.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to parse --transform argument, error was "
|
||||||
|
<< parse_status.error_message();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (transform_params.empty()) {
|
||||||
|
LOG(ERROR) << "You must specify at least one transform.\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
|
||||||
|
if (!load_status.ok()) {
|
||||||
|
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
|
||||||
|
<< load_status.error_message();
|
||||||
|
LOG(ERROR) << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status transform_result =
|
||||||
|
TransformGraph(inputs, outputs, transform_params, &graph_def);
|
||||||
|
|
||||||
|
if (!transform_result.ok()) {
|
||||||
|
LOG(ERROR) << transform_result.error_message();
|
||||||
|
LOG(ERROR) << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
|
||||||
|
if (!save_status.ok()) {
|
||||||
|
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
|
||||||
|
<< save_status.error_message();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reads the optional "ignore_errors" parameter of a transform.
//
// *ignore_errors is set to false when the parameter is absent or has no
// values. Otherwise the first value is lower-cased and must be "true" or
// "false"; any other text yields InvalidArgument (with *ignore_errors left
// false).
Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
                          bool* ignore_errors) {
  *ignore_errors = false;

  const auto found = transform_params.find("ignore_errors");
  if (found == transform_params.end() || found->second.empty()) {
    // Nothing was specified, so keep the default.
    return Status::OK();
  }

  const string requested = str_util::Lowercase(found->second.at(0));
  if (requested == "true") {
    *ignore_errors = true;
    return Status::OK();
  }
  if (requested == "false") {
    *ignore_errors = false;
    return Status::OK();
  }
  return errors::InvalidArgument(
      "ignore_errors should be true or false, found ", requested);
}
|
||||||
|
|
||||||
|
// Applies each requested transform, in order, to *graph_def.
//
// Entries with an empty transform name are skipped. A name missing from the
// transform registry yields InvalidArgument. Every transform receives the same
// input and output names plus its own parameters. If a transform fails and its
// "ignore_errors" parameter is true, the error is logged and the graph from
// before that transform is carried forward; otherwise the failure is returned.
// After each step the function library of the incoming graph is copied onto
// the result, which must pass IsGraphValid() before it replaces *graph_def.
Status TransformGraph(const std::vector<string>& inputs,
                      const std::vector<string>& outputs,
                      const TransformParameters& transform_params,
                      GraphDef* graph_def) {
  TransformRegistry* registry = GetTransformRegistry();
  for (const auto& name_and_params : transform_params) {
    const string& name = name_and_params.first;
    if (name == "") {
      continue;
    }
    if (registry->count(name) == 0) {
      return errors::InvalidArgument("Transform '", name, "' not recognized.");
    }
    LOG(INFO) << "Applying " << name;
    const TransformFunc& apply_transform = registry->at(name);

    TransformFuncContext context;
    context.input_names = inputs;
    context.output_names = outputs;
    context.params = name_and_params.second;

    bool ignore_errors;
    TF_RETURN_IF_ERROR(
        ShouldIgnoreErrors(name_and_params.second, &ignore_errors));

    GraphDef candidate;
    Status status = apply_transform(*graph_def, context, &candidate);
    if (!status.ok()) {
      if (!ignore_errors) {
        return status;
      }
      LOG(ERROR) << name << ": Ignoring error " << status.error_message();
      // Discard whatever the failed transform produced.
      candidate = *graph_def;
    }

    // Copy over the library from the original input graph.
    candidate.mutable_library()->CopyFrom(graph_def->library());
    TF_RETURN_IF_ERROR(IsGraphValid(candidate));

    *graph_def = candidate;
  }
  return Status::OK();
}
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
50
tensorflow/tools/graph_transforms/transform_graph.h
Normal file
50
tensorflow/tools/graph_transforms/transform_graph.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_
|
||||||
|
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Convenience function to handle argument parsing for the command line tool.
|
||||||
|
// If init_main is false, we're testing so don't call core initialization.
|
||||||
|
int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main);
|
||||||
|
|
||||||
|
// Handles converting the transforms string into transform names and their
|
||||||
|
// arguments.
|
||||||
|
typedef std::vector<std::pair<string, TransformFuncParameters>>
|
||||||
|
TransformParameters;
|
||||||
|
Status ParseTransformParameters(const string& transforms_string,
|
||||||
|
TransformParameters* params_list);
|
||||||
|
|
||||||
|
// Applies a series of transformations to the GraphDef. These transforms are
|
||||||
|
// defined by modules that call REGISTER_GRAPH_TRANSFORM() to associate a
|
||||||
|
// function with a name string.
|
||||||
|
Status TransformGraph(const std::vector<string>& inputs,
|
||||||
|
const std::vector<string>& outputs,
|
||||||
|
const TransformParameters& transform_params,
|
||||||
|
GraphDef* graph_def);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_GRAPH_H_
|
53
tensorflow/tools/graph_transforms/transform_graph_main.cc
Normal file
53
tensorflow/tools/graph_transforms/transform_graph_main.cc
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Tool that applies a series of transformations to a frozen GraphDef file.
|
||||||
|
// It takes a flexible list of transforms on the command line, and runs
|
||||||
|
// those on the incoming graph to produce the result. This allows you to build a
|
||||||
|
// processing pipeline when preparing models for deployment.
|
||||||
|
//
|
||||||
|
// bazel build tensorflow/tools/graph_transforms/fold_constants_tool &&
|
||||||
|
// bazel-bin/tensorflow/tools/graph_transforms/fold_constants_tool \
|
||||||
|
// --in_graph=graph_def.pb \
|
||||||
|
// --out_graph=transformed_graph_def.pb \
|
||||||
|
// --inputs=input1,input2 \
|
||||||
|
// --outputs=output1,output2 \
|
||||||
|
// --transforms="fold_constants order_nodes"
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// in_graph - name of a file with a frozen GraphDef proto in binary format.
|
||||||
|
// out_graph - name of the output file to save the transformed version to.
|
||||||
|
// inputs - layer names of the nodes that will be fed data.
|
||||||
|
// outputs - layer names of the nodes that will be read from after running.
|
||||||
|
// transforms - space-separated names of the transforms to apply.
|
||||||
|
//
|
||||||
|
// List of implemented transforms:
|
||||||
|
// fold_constants - Merges constant expression subgraphs into single constants,
|
||||||
|
// which can help reduce the number of ops and make subsequent transforms and
|
||||||
|
// optimizations more effective.
|
||||||
|
// order_nodes - Sorts the GraphDef nodes in execution order, which can help
|
||||||
|
// simple inference engines that want to avoid complexity in their executors.
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
return tensorflow::graph_transforms::ParseFlagsAndTransformGraph(argc, argv,
|
||||||
|
true);
|
||||||
|
}
|
228
tensorflow/tools/graph_transforms/transform_graph_test.cc
Normal file
228
tensorflow/tools/graph_transforms/transform_graph_test.cc
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_graph.h"
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declared here so we don't have to expose it in the public header.
|
||||||
|
Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
|
||||||
|
bool* ignore_errors);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Transform used only by the tests below: ignores its input graph and context
// and always produces an empty graph.
Status test_empty_graph_transform(const GraphDef& graph_def,
                                  const TransformFuncContext& context,
                                  GraphDef* result) {
  // Drop anything the caller left in the output.
  result->Clear();
  return Status::OK();
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform",
|
||||||
|
test_empty_graph_transform);
|
||||||
|
|
||||||
|
class TransformGraphTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestConstantFolding() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 100;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const =
|
||||||
|
Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const =
|
||||||
|
Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const);
|
||||||
|
|
||||||
|
Output placeholder =
|
||||||
|
Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul =
|
||||||
|
Mul(root.WithOpName("output_expect_remains"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
string graph_def_serialized;
|
||||||
|
graph_def.SerializeToString(&graph_def_serialized);
|
||||||
|
const string dir = testing::TmpDir();
|
||||||
|
const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb");
|
||||||
|
const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb");
|
||||||
|
TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb,
|
||||||
|
graph_def_serialized));
|
||||||
|
|
||||||
|
std::vector<string> args = {"some_binary",
|
||||||
|
"--in_graph=" + in_filename_pb,
|
||||||
|
"--out_graph=" + out_filename_pb,
|
||||||
|
"--inputs=placeholder_expect_remains",
|
||||||
|
"--outputs=output_expect_remains",
|
||||||
|
"--transforms=fold_constants"};
|
||||||
|
const int argc = 6;
|
||||||
|
EXPECT_EQ(argc, args.size());
|
||||||
|
char* argv[argc];
|
||||||
|
std::vector<char*> char_strings;
|
||||||
|
for (int i = 0; i < argc; ++i) {
|
||||||
|
string arg = args[i];
|
||||||
|
char* char_string = new char[arg.size() + 1];
|
||||||
|
std::copy_n(arg.c_str(), arg.size() + 1, char_string);
|
||||||
|
argv[i] = char_string;
|
||||||
|
char_strings.push_back(char_string);
|
||||||
|
}
|
||||||
|
ParseFlagsAndTransformGraph(argc, argv, false);
|
||||||
|
for (char* char_string : char_strings) {
|
||||||
|
delete[] char_string;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef out_graph_def;
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> out_node_map;
|
||||||
|
graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map);
|
||||||
|
|
||||||
|
for (const NodeDef& node : out_graph_def.node()) {
|
||||||
|
const StringPiece name(node.name());
|
||||||
|
const int occurrence_count = out_node_map.count(node.name());
|
||||||
|
if (name.ends_with("expect_removed")) {
|
||||||
|
EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name();
|
||||||
|
}
|
||||||
|
if (name.ends_with("expect_remains")) {
|
||||||
|
EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestTransformRegistration() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
Output placeholder =
|
||||||
|
Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
EXPECT_EQ(1, graph_def.node().size());
|
||||||
|
TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}},
|
||||||
|
&graph_def));
|
||||||
|
EXPECT_EQ(0, graph_def.node().size());
|
||||||
|
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
Status no_such_status =
|
||||||
|
TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def);
|
||||||
|
EXPECT_TRUE(
|
||||||
|
StringPiece(no_such_status.ToString()).contains("not recognized"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestParseTransformParameters() {
|
||||||
|
TransformParameters params_list;
|
||||||
|
|
||||||
|
ParseTransformParameters("foo", ¶ms_list);
|
||||||
|
EXPECT_EQ(1, params_list.size());
|
||||||
|
EXPECT_EQ("foo", params_list[0].first);
|
||||||
|
EXPECT_TRUE(params_list[0].second.empty());
|
||||||
|
|
||||||
|
ParseTransformParameters("foo bar", ¶ms_list);
|
||||||
|
EXPECT_EQ(2, params_list.size());
|
||||||
|
EXPECT_EQ("foo", params_list[0].first);
|
||||||
|
EXPECT_TRUE(params_list[0].second.empty());
|
||||||
|
EXPECT_EQ("bar", params_list[1].first);
|
||||||
|
EXPECT_TRUE(params_list[1].second.empty());
|
||||||
|
|
||||||
|
ParseTransformParameters("foo() bar()", ¶ms_list);
|
||||||
|
EXPECT_EQ(2, params_list.size());
|
||||||
|
EXPECT_EQ("foo", params_list[0].first);
|
||||||
|
EXPECT_TRUE(params_list[0].second.empty());
|
||||||
|
EXPECT_EQ("bar", params_list[1].first);
|
||||||
|
EXPECT_TRUE(params_list[1].second.empty());
|
||||||
|
|
||||||
|
ParseTransformParameters("foo(bob_something=sue)", ¶ms_list);
|
||||||
|
EXPECT_EQ(1, params_list.size());
|
||||||
|
EXPECT_EQ("foo", params_list[0].first);
|
||||||
|
EXPECT_EQ(1, params_list[0].second.count("bob_something"));
|
||||||
|
EXPECT_EQ(1, params_list[0].second["bob_something"].size());
|
||||||
|
EXPECT_EQ("sue", params_list[0].second["bob_something"][0]);
|
||||||
|
|
||||||
|
ParseTransformParameters("bar(a=1, b=2, a=3)", ¶ms_list);
|
||||||
|
EXPECT_EQ(1, params_list.size());
|
||||||
|
EXPECT_EQ("bar", params_list[0].first);
|
||||||
|
EXPECT_EQ(1, params_list[0].second.count("a"));
|
||||||
|
EXPECT_EQ(2, params_list[0].second["a"].size());
|
||||||
|
EXPECT_EQ("1", params_list[0].second["a"][0]);
|
||||||
|
EXPECT_EQ("3", params_list[0].second["a"][1]);
|
||||||
|
EXPECT_EQ(1, params_list[0].second.count("b"));
|
||||||
|
EXPECT_EQ(1, params_list[0].second["b"].size());
|
||||||
|
EXPECT_EQ("2", params_list[0].second["b"][0]);
|
||||||
|
|
||||||
|
ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)", ¶ms_list);
|
||||||
|
EXPECT_EQ(1, params_list.size());
|
||||||
|
EXPECT_EQ("bar", params_list[0].first);
|
||||||
|
EXPECT_EQ(1, params_list[0].second.count("a"));
|
||||||
|
EXPECT_EQ(2, params_list[0].second["a"].size());
|
||||||
|
EXPECT_EQ("1", params_list[0].second["a"][0]);
|
||||||
|
EXPECT_EQ("3", params_list[0].second["a"][1]);
|
||||||
|
EXPECT_EQ(1, params_list[0].second.count("b"));
|
||||||
|
EXPECT_EQ(1, params_list[0].second["b"].size());
|
||||||
|
EXPECT_EQ("1,2,3", params_list[0].second["b"][0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestShouldIgnoreErrors() {
|
||||||
|
bool ignore_errors;
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors));
|
||||||
|
EXPECT_TRUE(ignore_errors);
|
||||||
|
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors));
|
||||||
|
EXPECT_FALSE(ignore_errors);
|
||||||
|
|
||||||
|
TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors));
|
||||||
|
EXPECT_FALSE(ignore_errors);
|
||||||
|
|
||||||
|
EXPECT_FALSE(
|
||||||
|
ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Each test case simply forwards to the fixture method of the same name.
TEST_F(TransformGraphTest, TestConstantFolding) {
  TestConstantFolding();
}

TEST_F(TransformGraphTest, TestTransformRegistration) {
  TestTransformRegistration();
}

TEST_F(TransformGraphTest, TestParseTransformParameters) {
  TestParseTransformParameters();
}

TEST_F(TransformGraphTest, TestShouldIgnoreErrors) {
  TestShouldIgnoreErrors();
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
@ -15,12 +15,50 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/public/session.h"
|
#include "tensorflow/core/public/session.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace graph_transforms {
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
inline bool IsMerge(const NodeDef& node_def) {
|
||||||
|
return node_def.op() == "Merge" || node_def.op() == "RefMerge";
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecordMatchedNodes(const NodeMatch& match,
|
||||||
|
std::set<string>* matched_nodes) {
|
||||||
|
matched_nodes->insert(match.node.name());
|
||||||
|
for (const NodeMatch& input_match : match.inputs) {
|
||||||
|
RecordMatchedNodes(input_match, matched_nodes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience wrapper that hashes the bytes of a string with Hash64().
inline uint64 Hash64String(const string& input) {
  const char* bytes = input.data();
  const size_t length = input.size();
  return Hash64(bytes, length);
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
|
||||||
|
std::set<string> found_nodes;
|
||||||
|
std::vector<NodeMatch> current_matches = {match};
|
||||||
|
while (!current_matches.empty()) {
|
||||||
|
std::vector<NodeMatch> next_matches;
|
||||||
|
for (const NodeMatch& current_match : current_matches) {
|
||||||
|
if (found_nodes.count(current_match.node.name())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
found_nodes.insert(current_match.node.name());
|
||||||
|
result->push_back(current_match.node);
|
||||||
|
for (const NodeMatch& input_match : current_match.inputs) {
|
||||||
|
next_matches.push_back(input_match);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current_matches = next_matches;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void MapNamesToNodes(const GraphDef& graph_def,
|
void MapNamesToNodes(const GraphDef& graph_def,
|
||||||
std::map<string, const NodeDef*>* result) {
|
std::map<string, const NodeDef*>* result) {
|
||||||
for (const NodeDef& node : graph_def.node()) {
|
for (const NodeDef& node : graph_def.node()) {
|
||||||
@ -28,7 +66,19 @@ void MapNamesToNodes(const GraphDef& graph_def,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void NodeNamePartsFromInput(string input_name, string* prefix,
|
void MapNodesToOutputs(const GraphDef& graph_def,
|
||||||
|
std::map<string, std::vector<const NodeDef*>>* result) {
|
||||||
|
std::map<string, const NodeDef*> node_map;
|
||||||
|
MapNamesToNodes(graph_def, &node_map);
|
||||||
|
for (const NodeDef& node : graph_def.node()) {
|
||||||
|
for (const string& input : node.input()) {
|
||||||
|
string input_node_name = NodeNameFromInput(input);
|
||||||
|
(*result)[input_node_name].push_back(&node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void NodeNamePartsFromInput(const string& input_name, string* prefix,
|
||||||
string* node_name, string* suffix) {
|
string* node_name, string* suffix) {
|
||||||
std::vector<string> input_parts = str_util::Split(input_name, ':');
|
std::vector<string> input_parts = str_util::Split(input_name, ':');
|
||||||
if (input_parts.size() < 2) {
|
if (input_parts.size() < 2) {
|
||||||
@ -45,7 +95,7 @@ void NodeNamePartsFromInput(string input_name, string* prefix,
|
|||||||
*node_name = node_name_piece.ToString();
|
*node_name = node_name_piece.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string NodeNameFromInput(string input_name) {
|
string NodeNameFromInput(const string& input_name) {
|
||||||
string prefix;
|
string prefix;
|
||||||
string node_name;
|
string node_name;
|
||||||
string suffix;
|
string suffix;
|
||||||
@ -53,6 +103,57 @@ string NodeNameFromInput(string input_name) {
|
|||||||
return node_name;
|
return node_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns input_name rebuilt from its prefix, node name and suffix parts, with
// ":0" used as the suffix when the input carried none, so that inputs naming
// the same output compare equal as strings.
string CanonicalInputName(const string& input_name) {
  string prefix_part;
  string name_part;
  string suffix_part;
  NodeNamePartsFromInput(input_name, &prefix_part, &name_part, &suffix_part);
  const string port = suffix_part.empty() ? string(":0") : suffix_part;
  return prefix_part + name_part + port;
}
|
||||||
|
|
||||||
|
// Computes a 64-bit hash over a node's op, name, canonicalized inputs, device
// and attributes. Attributes are folded in sorted key order so the result
// does not depend on the iteration order of the attribute map.
uint64 HashNodeDef(const NodeDef& node) {
  uint64 hash = Hash64String(node.op());
  hash = Hash64Combine(hash, Hash64String(node.name()));
  for (const string& input : node.input()) {
    hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
  }
  hash = Hash64Combine(hash, Hash64String(node.device()));

  std::vector<string> attr_names;
  attr_names.reserve(node.attr().size());
  for (const auto& attr : node.attr()) {
    attr_names.push_back(attr.first);
  }
  std::sort(attr_names.begin(), attr_names.end());

  string attr_serialized;
  for (const string& attr_name : attr_names) {
    // Bind by reference: the attribute is only read for serialization, so
    // there is no need to copy it first.
    const auto& attr = node.attr().at(attr_name);
    attr.SerializeToString(&attr_serialized);
    hash = Hash64Combine(hash, Hash64String(attr_serialized));
  }
  return hash;
}
|
||||||
|
|
||||||
|
void AddNodeInput(const string& input_name, NodeDef* node) {
|
||||||
|
*(node->mutable_input()->Add()) = input_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CopyNodeAttr(const NodeDef& source, const string& source_key,
|
||||||
|
const string& dest_key, NodeDef* dest) {
|
||||||
|
CHECK_NE(0, source.attr().count(source_key))
|
||||||
|
<< "No key '" << source_key << "' found in " << source.DebugString();
|
||||||
|
(*(dest->mutable_attr()))[dest_key].CopyFrom(source.attr().at(source_key));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parses the tensor held in node's attribute named key and returns it as a
// Tensor. CHECK-fails if the stored proto cannot be converted.
Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
  // Bind by reference so the TensorProto is not copied before being parsed.
  const TensorProto& tensor_proto = node.attr().at(key).tensor();
  Tensor tensor;
  CHECK(tensor.FromProto(tensor_proto));
  return tensor;
}
|
||||||
|
|
||||||
void FilterGraphDef(const GraphDef& input_graph_def,
|
void FilterGraphDef(const GraphDef& input_graph_def,
|
||||||
std::function<bool(const NodeDef&)> selector,
|
std::function<bool(const NodeDef&)> selector,
|
||||||
GraphDef* output_graph_def) {
|
GraphDef* output_graph_def) {
|
||||||
@ -77,5 +178,425 @@ void RemoveAttributes(const GraphDef& input_graph_def,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Writes the nodes of input_graph_def into output_graph_def so that every node
// appears after the nodes it takes as inputs.
//
// Returns InvalidArgument if a node names an input that is not in the graph
// (output_graph_def is left untouched in that case), or if some nodes could
// never become ready, which is reported as a cycle (output_graph_def then
// holds only the nodes that were emitted).
Status SortByExecutionOrder(const GraphDef& input_graph_def,
                            GraphDef* output_graph_def) {
  const int num_nodes = input_graph_def.node_size();

  // Position of each node in the input list, keyed by node name.
  std::map<string, int> name_index;
  for (int i = 0; i < num_nodes; ++i) {
    name_index[input_graph_def.node(i).name()] = i;
  }

  // ready: nodes whose inputs have all been emitted (used as a stack).
  // pending_count: how many more inputs each node is waiting for.
  // outputs: for each node, the indices of the nodes that consume it.
  std::vector<int> ready;
  std::vector<int> pending_count;
  pending_count.reserve(num_nodes);
  std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);

  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = input_graph_def.node(n);
    const int input_count = node_def.input_size();

    int wait_count = input_count;
    if (IsMerge(node_def)) {
      // A merge only waits for its control inputs plus one data input.
      int32 num_control_edges = 0;
      for (const string& merge_input : node_def.input()) {
        if (StringPiece(merge_input).starts_with("^")) {
          ++num_control_edges;
        }
      }
      wait_count = num_control_edges + 1;
    }
    pending_count.push_back(wait_count);

    if (input_count == 0) {
      ready.push_back(n);
      continue;
    }

    for (const string& input_name : node_def.input()) {
      const string input_node_name = NodeNameFromInput(input_name);
      const auto producer = name_index.find(input_node_name);
      if (producer == name_index.end()) {
        return errors::InvalidArgument("Node '", node_def.name(),
                                       "': Unknown input node '", input_name,
                                       "'");
      }
      outputs[producer->second].push_back(n);
    }
  }

  // Emit nodes as they become ready, releasing their consumers as we go.
  int processed = 0;
  output_graph_def->Clear();
  while (!ready.empty()) {
    const int current = ready.back();
    ready.pop_back();
    ++processed;
    output_graph_def->mutable_node()->Add()->CopyFrom(
        input_graph_def.node(current));

    for (const int consumer : outputs[current]) {
      pending_count[consumer]--;
      if (pending_count[consumer] == 0) {
        ready.push_back(consumer);
      }
    }
  }

  if (processed < num_nodes) {
    return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle");
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Renders the pattern as "{op, {child,child,}}", recursing into each input
// pattern and following every child with a comma.
string OpTypePattern::DebugString() const {
  string text = "{" + op + ", {";
  for (const OpTypePattern& child : inputs) {
    text += child.DebugString();
    text += ",";
  }
  text += "}}";
  return text;
}
|
||||||
|
|
||||||
|
// Renders the match as "{<node debug string>, {child,child,}}", recursing into
// each input match and following every child with a comma.
string NodeMatch::DebugString() const {
  string text = "{";
  text += node.DebugString();
  text += ", {";
  for (const NodeMatch& child : inputs) {
    text += child.DebugString();
    text += ",";
  }
  text += "}}";
  return text;
}
|
||||||
|
|
||||||
|
// Stores a copy of graph_def sorted into execution order and indexes its nodes
// by name. A constructor cannot return a Status, so a sort failure is logged
// instead of being dropped silently; in that case the stored graph may be
// empty or incomplete and fewer matches will be found.
GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
  const Status sort_status = SortByExecutionOrder(graph_def, &graph_def_);
  if (!sort_status.ok()) {
    LOG(ERROR) << "GraphMatcher could not sort the graph into execution "
               << "order: " << sort_status.ToString();
  }
  MapNamesToNodes(graph_def_, &node_map_);
}
|
||||||
|
|
||||||
|
// Walks the stored graph in order and appends to matches every sub-graph that
// fits pattern. Once a node has been used in a match it is not considered
// again, so no node appears in more than one match. Always returns OK.
Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
                                      std::vector<NodeMatch>* matches) {
  std::set<string> claimed_nodes;
  for (const NodeDef& candidate : graph_def_.node()) {
    // Nodes already consumed by an earlier match cannot head a new one.
    if (claimed_nodes.count(candidate.name()) != 0) {
      continue;
    }
    NodeMatch candidate_match;
    if (!DoesOpTypeMatch(candidate, pattern, claimed_nodes,
                         &candidate_match)) {
      continue;
    }
    RecordMatchedNodes(candidate_match, &claimed_nodes);
    matches->push_back(candidate_match);
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
bool GraphMatcher::DoesOpTypeMatch(
|
||||||
|
const NodeDef& node, const OpTypePattern& pattern,
|
||||||
|
const std::set<string>& previously_matched_nodes, NodeMatch* match) {
|
||||||
|
VLOG(1) << "Looking at node " << node.DebugString();
|
||||||
|
VLOG(1) << "pattern=" << pattern.DebugString();
|
||||||
|
VLOG(1) << "match=" << match->DebugString();
|
||||||
|
if (previously_matched_nodes.count(node.name())) {
|
||||||
|
VLOG(1) << "node " << node.name() << " has been previously matched";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool pattern_matched = false;
|
||||||
|
if (pattern.op == "*") {
|
||||||
|
pattern_matched = true;
|
||||||
|
} else {
|
||||||
|
std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
|
||||||
|
for (const string& pattern_op : pattern_ops) {
|
||||||
|
if (node.op() == pattern_op) {
|
||||||
|
pattern_matched = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!pattern_matched) {
|
||||||
|
VLOG(1) << "node.op() != pattern.op()";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
match->node = node;
|
||||||
|
// Ignore any control inputs for pattern-matching purposes
|
||||||
|
std::vector<string> non_control_inputs;
|
||||||
|
for (const string& input : node.input()) {
|
||||||
|
if (!input.empty() && (input[0] != '^')) {
|
||||||
|
non_control_inputs.push_back(input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pattern.inputs.empty()) {
|
||||||
|
// If there are no inputs, assume that's the end of the pattern.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (non_control_inputs.size() != pattern.inputs.size()) {
|
||||||
|
VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < pattern.inputs.size(); ++i) {
|
||||||
|
const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
|
||||||
|
const NodeDef& input_node = *(node_map_[input_node_name]);
|
||||||
|
const OpTypePattern& input_pattern = pattern.inputs[i];
|
||||||
|
match->inputs.push_back(NodeMatch());
|
||||||
|
NodeMatch* input_match = &(match->inputs.back());
|
||||||
|
if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
|
||||||
|
input_match)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies input_graph_def into output_graph_def, replacing every sub-graph that
// fits pattern with the nodes produced by node_generator.
//
// For each match the generator receives the match itself, the names of
// matched nodes that read from outside the match, the names of matched nodes
// that are read by nodes outside the match, and a vector to fill with the
// replacement nodes. Unless options.allow_inconsistencies is set, a
// replacement that drops one of those externally-read nodes is discarded and
// the original matched nodes are copied instead. Nodes that belong to no
// match are copied unchanged.
//
// Returns the first error from the matcher or from node_generator, else OK.
Status ReplaceMatchingOpTypes(
    const GraphDef& input_graph_def, const OpTypePattern& pattern,
    const std::function<Status(const NodeMatch&, const std::set<string>&,
                               const std::set<string>&, std::vector<NodeDef>*)>&
        node_generator,
    const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
  // Start off by retrieving all the matching subgraphs.
  GraphMatcher matcher(input_graph_def);
  std::vector<NodeMatch> matches;
  TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches));

  // Do some housekeeping so we can easily look up the resulting matches given
  // a node name.
  std::set<string> matched_nodes;
  std::map<string, const NodeMatch*> matches_by_head_name;
  for (const NodeMatch& match : matches) {
    matches_by_head_name[match.node.name()] = &match;
    RecordMatchedNodes(match, &matched_nodes);
  }
  std::map<string, std::vector<const NodeDef*>> outputs_map;
  MapNodesToOutputs(input_graph_def, &outputs_map);

  // Go through all the nodes in the input graph, see if they are part of a
  // match or if they can be left untouched.
  output_graph_def->Clear();
  for (const NodeDef& input_node : input_graph_def.node()) {
    const auto head_entry = matches_by_head_name.find(input_node.name());
    if (head_entry != matches_by_head_name.end()) {
      // This node is the beginning of a match, so call the replacement
      // function after setting up some information it will need.
      const NodeMatch* match = head_entry->second;
      std::vector<NodeDef> matched_nodes_array;
      MatchedNodesAsArray(*match, &matched_nodes_array);

      // This tells us whether a node is part of the current match.
      std::set<string> matched_nodes_lookup;
      for (const NodeDef& matched_node : matched_nodes_array) {
        matched_nodes_lookup.insert(matched_node.name());
      }

      // These are helper sets that the replacement function can use to tell
      // whether it can safely remove an internal node (because nothing
      // outside of the match uses it) or whether external nodes depend on it.
      std::set<string> input_nodes;
      std::set<string> output_nodes;
      for (const NodeDef& matched_node : matched_nodes_array) {
        // If any of this node's inputs come from outside the match, note the
        // node as one that reads external inputs.
        for (const string& input_name : matched_node.input()) {
          if (!matched_nodes_lookup.count(NodeNameFromInput(input_name))) {
            input_nodes.insert(matched_node.name());
            break;
          }
        }
        // Do a reverse input lookup, to see which other nodes use the current
        // one as an input. If any of those are outside the match, the current
        // node is an output that shouldn't be removed.
        const auto consumers = outputs_map.find(matched_node.name());
        if (consumers != outputs_map.end()) {
          for (const NodeDef* dependent_node : consumers->second) {
            if (!matched_nodes_lookup.count(dependent_node->name())) {
              output_nodes.insert(matched_node.name());
              break;
            }
          }
        }
      }

      // Call the generator function to obtain the replacement nodes.
      std::vector<NodeDef> new_nodes;
      TF_RETURN_IF_ERROR(
          node_generator(*match, input_nodes, output_nodes, &new_nodes));
      std::set<string> new_node_names;
      for (const NodeDef& new_node : new_nodes) {
        new_node_names.insert(new_node.name());
      }

      // Check to make sure the generator function preserved all of the nodes
      // that are used elsewhere in the graph.
      bool abort_replacement = false;
      if (!options.allow_inconsistencies) {
        for (const string& expected_output : output_nodes) {
          if (!new_node_names.count(expected_output)) {
            LOG(WARNING) << "Expected " << expected_output
                         << " to be preserved.";
            abort_replacement = true;
          }
        }
      }

      if (abort_replacement) {
        LOG(WARNING) << "Generator function didn't preserve needed nodes, "
                     << "copying old replacements back in instead.";
        // matched_nodes_array already holds the original nodes of this match.
        for (const NodeDef& old_node : matched_nodes_array) {
          output_graph_def->mutable_node()->Add()->CopyFrom(old_node);
        }
      } else {
        for (const NodeDef& new_node : new_nodes) {
          output_graph_def->mutable_node()->Add()->CopyFrom(new_node);
        }
      }
    } else if (!matched_nodes.count(input_node.name())) {
      // This node isn't part of any match, so just copy it over.
      output_graph_def->mutable_node()->Add()->CopyFrom(input_node);
    } else {
      // Do nothing, because this is an internal part of a matching subgraph,
      // and so will have been replaced by a new replacement subgraph.
    }
  }

  return Status::OK();
}
|
||||||
|
|
||||||
|
// Copies input_graph_def into output_graph_def, rewriting node inputs
// according to inputs_to_rename (source input -> destination input).
//
// A source ending in ":*" applies to every output of that node: the input's
// own prefix and suffix are kept and only the node name is replaced. Any other
// source must equal the input after both are canonicalized, and the input is
// replaced by the destination as given. Renames are followed repeatedly, so a
// destination that is itself a source is renamed again. Returns
// InvalidArgument if following renames revisits a node name.
Status RenameNodeInputs(const GraphDef& input_graph_def,
                        const std::map<string, string>& inputs_to_rename,
                        GraphDef* output_graph_def) {
  // Group the rename rules by the plain node name of their source.
  std::map<string, std::vector<std::pair<string, string>>> rules_by_node;
  for (const auto& rule : inputs_to_rename) {
    rules_by_node[NodeNameFromInput(rule.first)].emplace_back(rule.first,
                                                              rule.second);
  }

  output_graph_def->Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    NodeDef* new_node = output_graph_def->mutable_node()->Add();
    new_node->CopyFrom(node);
    new_node->mutable_input()->Clear();

    for (const string& input_name : node.input()) {
      std::set<string> visited_node_names;
      string current_name = input_name;
      while (true) {
        const string current_node_name = NodeNameFromInput(current_name);
        const auto rules = rules_by_node.find(current_node_name);
        if (rules == rules_by_node.end()) {
          break;
        }
        if (!visited_node_names.insert(current_node_name).second) {
          return errors::InvalidArgument(
              "RenameNodeInputs argument contains a cycle for ",
              current_node_name);
        }

        bool renamed = false;
        for (const std::pair<string, string>& rule : rules->second) {
          const string& source_name = rule.first;
          const string& dest_name = rule.second;
          if (StringPiece(source_name).ends_with(":*")) {
            // Wildcard: swap the node name, keep this input's own decoration.
            string prefix;
            string unused_node_name;
            string suffix;
            NodeNamePartsFromInput(current_name, &prefix, &unused_node_name,
                                   &suffix);
            current_name = prefix + dest_name + suffix;
            renamed = true;
          } else if (CanonicalInputName(source_name) ==
                     CanonicalInputName(current_name)) {
            current_name = dest_name;
            renamed = true;
          }
        }
        if (!renamed) {
          break;
        }
      }
      *(new_node->mutable_input()->Add()) = current_name;
    }
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
void CopyOriginalMatch(const NodeMatch& match,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
std::vector<NodeDef> old_nodes;
|
||||||
|
MatchedNodesAsArray(match, &old_nodes);
|
||||||
|
for (const NodeDef& old_node : old_nodes) {
|
||||||
|
new_nodes->push_back(old_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a pointer to the single registry instance, which is a function-local
// static created on the first call.
TransformRegistry* GetTransformRegistry() {
  static TransformRegistry registry;
  return &registry;
}
|
||||||
|
|
||||||
|
void FindInvalidInputs(const GraphDef& graph_def,
|
||||||
|
std::vector<std::pair<string, string>>* invalid_inputs) {
|
||||||
|
std::map<string, const NodeDef*> node_map;
|
||||||
|
MapNamesToNodes(graph_def, &node_map);
|
||||||
|
|
||||||
|
for (const NodeDef& node : graph_def.node()) {
|
||||||
|
for (const string& input : node.input()) {
|
||||||
|
string input_node = NodeNameFromInput(input);
|
||||||
|
if (!node_map.count(input_node)) {
|
||||||
|
invalid_inputs->push_back({node.name(), input_node});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns OK when every input in graph_def refers to an existing node.
// Otherwise logs each offending input together with the node that uses it and
// returns an Internal error.
Status IsGraphValid(const GraphDef& graph_def) {
  std::vector<std::pair<string, string>> invalid_inputs;
  FindInvalidInputs(graph_def, &invalid_inputs);
  if (invalid_inputs.empty()) {
    return Status::OK();
  }

  std::map<string, const NodeDef*> node_map;
  MapNamesToNodes(graph_def, &node_map);
  for (const std::pair<string, string>& invalid_input : invalid_inputs) {
    const string& consumer_name = invalid_input.first;
    const string& missing_name = invalid_input.second;
    LOG(ERROR) << "Invalid input " << missing_name << " for node "
               << consumer_name << " - "
               << node_map[consumer_name]->DebugString();
  }
  return errors::Internal(
      "Invalid graph with inputs referring to nonexistent nodes");
}
|
||||||
|
|
||||||
|
int CountParameters(const TransformFuncContext& context, const string& name) {
|
||||||
|
if (context.params.count(name)) {
|
||||||
|
return context.params.at(name).size();
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores in result the single value given for parameter name, or
// default_value when the parameter was not supplied. Returns InvalidArgument,
// leaving result untouched, when the parameter occurs more than once.
Status GetExactlyOneParameter(const TransformFuncContext& context,
                              const string& name, const string& default_value,
                              string* result) {
  const int params_count = CountParameters(context, name);
  if (params_count > 1 || params_count < 0) {
    return errors::InvalidArgument("Expected a single '", name,
                                   "' parameter, but found ", params_count,
                                   " occurrences");
  }
  if (params_count == 1) {
    *result = context.params.at(name).at(0);
  } else {
    *result = default_value;
  }
  return Status::OK();
}
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,7 +16,11 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
|
#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
|
||||||
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
|
#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -26,18 +30,73 @@ namespace graph_transforms {
|
|||||||
void MapNamesToNodes(const GraphDef& graph_def,
|
void MapNamesToNodes(const GraphDef& graph_def,
|
||||||
std::map<string, const NodeDef*>* result);
|
std::map<string, const NodeDef*>* result);
|
||||||
|
|
||||||
|
// For every node in the graph create a list of the nodes that use it as an
|
||||||
|
// input.
|
||||||
|
void MapNodesToOutputs(const GraphDef& graph_def,
|
||||||
|
std::map<string, std::vector<const NodeDef*>>* result);
|
||||||
|
|
||||||
// NodeDef input strings can contain other information besides the name of an
|
// NodeDef input strings can contain other information besides the name of an
|
||||||
// input node. These include:
|
// input node. These include:
|
||||||
// - Optional '^' prefix, indicating this is a control edge.
|
// - Optional '^' prefix, indicating this is a control edge.
|
||||||
// - The required name of the input node.
|
// - The required name of the input node.
|
||||||
// - Option ':<number>' suffix, showing which output of the node to use.
|
// - Optional ':<number>' suffix, showing which output of the node to use.
|
||||||
// This function takes a raw string, and breaks it into those component parts.
|
// This function takes a raw string, and breaks it into those component parts.
|
||||||
void NodeNamePartsFromInput(string input_name, string* prefix,
|
// The rules for inputs in function libraries are a bit more complex, and
|
||||||
|
// aren't handled by this routine.
|
||||||
|
void NodeNamePartsFromInput(const string& input_name, string* prefix,
|
||||||
string* node_name, string* suffix);
|
string* node_name, string* suffix);
|
||||||
|
|
||||||
|
// Adds a ':0' port to any inputs with no suffix, to make comparisons easier.
|
||||||
|
string CanonicalInputName(const string& input_name);
|
||||||
|
|
||||||
// Convenience function to strip the optional prefix and suffix components from
|
// Convenience function to strip the optional prefix and suffix components from
|
||||||
// a string pulled from a NodeDef input, and return the plain node name.
|
// a string pulled from a NodeDef input, and return the plain node name.
|
||||||
string NodeNameFromInput(string input_name);
|
string NodeNameFromInput(const string& input_name);
|
||||||
|
|
||||||
|
// Returns a stable hash for the contents of the NodeDef, so that equivalent
|
||||||
|
// nodes should have equal hashes.
|
||||||
|
uint64 HashNodeDef(const NodeDef& node);
|
||||||
|
|
||||||
|
// Adds the given node name to the end of the node's inputs.
|
||||||
|
void AddNodeInput(const string& input_name, NodeDef* node);
|
||||||
|
|
||||||
|
// Copies an attribute from one NodeDef to another.
|
||||||
|
void CopyNodeAttr(const NodeDef& source, const string& source_key,
|
||||||
|
const string& dest_key, NodeDef* dest);
|
||||||
|
|
||||||
|
// Inserts a value into a NodeDef's map of attributes.
|
||||||
|
// This is a bit different than AddNodeAttr in node_def_util.h because it
|
||||||
|
// overwrites any existing attributes with the same key.
|
||||||
|
template <class T>
|
||||||
|
inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
|
||||||
|
AttrValue attr_value;
|
||||||
|
SetAttrValue(value, &attr_value);
|
||||||
|
auto* attr_map = node->mutable_attr();
|
||||||
|
(*attr_map)[key] = attr_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
inline void SetNodeTensorAttr(const string& key, const Tensor& tensor,
|
||||||
|
NodeDef* node) {
|
||||||
|
TensorProto tensor_proto;
|
||||||
|
tensor.AsProtoTensorContent(&tensor_proto);
|
||||||
|
SetNodeAttr(key, tensor_proto, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inserts a Tensor into the specified attribute of a NodeDef.
|
||||||
|
template <class T>
|
||||||
|
inline void SetNodeTensorAttr(const string& key, const TensorShape& shape,
|
||||||
|
const std::vector<T>& values, NodeDef* node) {
|
||||||
|
const DataType dtype = DataTypeToEnum<T>::v();
|
||||||
|
CHECK_EQ(shape.num_elements(), values.size());
|
||||||
|
Tensor tensor(dtype, shape);
|
||||||
|
T* dest_data = tensor.flat<T>().data();
|
||||||
|
std::copy_n(values.data(), values.size(), dest_data);
|
||||||
|
SetNodeTensorAttr<T>(key, tensor, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves a tensor value from a NodeDef attribute.
|
||||||
|
Tensor GetNodeTensorAttr(const NodeDef& node, const string& key);
|
||||||
|
|
||||||
// Creates a copy of the input GraphDef, but only containing the nodes where the
|
// Creates a copy of the input GraphDef, but only containing the nodes where the
|
||||||
// supplied selector function returned true.
|
// supplied selector function returned true.
|
||||||
@ -51,6 +110,144 @@ void RemoveAttributes(const GraphDef& input_graph_def,
|
|||||||
const std::vector<string>& attributes,
|
const std::vector<string>& attributes,
|
||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
// For a lot of replacement and matching operations it's useful to have the
|
||||||
|
// nodes processed in a controlled order, so this does a topological sort to
|
||||||
|
// ensure that nodes always appear in the GraphDef.node list after their inputs.
|
||||||
|
Status SortByExecutionOrder(const GraphDef& input_graph_def,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
// Finds inputs that refer to nodes that are not in the graph.
|
||||||
|
void FindInvalidInputs(const GraphDef& graph_def,
|
||||||
|
std::vector<std::pair<string, string>>* invalid_inputs);
|
||||||
|
|
||||||
|
// Returns a descriptive error status if there are problems spotted with the
|
||||||
|
// graph.
|
||||||
|
Status IsGraphValid(const GraphDef& graph_def);
|
||||||
|
|
||||||
|
// This is used to spot particular subgraphs in a larger model. To use it,
|
||||||
|
// create a pattern like:
|
||||||
|
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
||||||
|
// This defines a subgraph where a Conv2D has a ResizeBilinear input, which
|
||||||
|
// pulls from a MirrorPad op.
|
||||||
|
// Regular expressions aren't supported for the op names, but you can use "*" to
|
||||||
|
// match any op. You can also use | as a separator to match multiple op names,
|
||||||
|
// like "Reshape|Concat|Conv2D".
|
||||||
|
struct OpTypePattern {
|
||||||
|
string op;
|
||||||
|
std::vector<OpTypePattern> inputs;
|
||||||
|
string DebugString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns a sub-graph of nodes that match a pattern.
|
||||||
|
struct NodeMatch {
|
||||||
|
NodeMatch() : node() {}
|
||||||
|
NodeDef node;
|
||||||
|
std::vector<NodeMatch> inputs;
|
||||||
|
string DebugString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Utility class to spot subgraphs matching particular patterns.
|
||||||
|
class GraphMatcher {
|
||||||
|
public:
|
||||||
|
GraphMatcher(const GraphDef& graph_def);
|
||||||
|
|
||||||
|
// Sorts the input nodes into execution order, and then skips any previously
|
||||||
|
// matches so that no node appears in more than one match. The NodeDef
|
||||||
|
// pointers contained in the results are owned by the GraphMatcher object, and
|
||||||
|
// so will be invalid after its lifetime.
|
||||||
|
Status GetOpTypeMatches(const OpTypePattern& pattern,
|
||||||
|
std::vector<NodeMatch>* matches);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern,
|
||||||
|
const std::set<string>& previously_matched_nodes,
|
||||||
|
NodeMatch* match);
|
||||||
|
|
||||||
|
GraphDef graph_def_;
|
||||||
|
std::map<string, const NodeDef*> node_map_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ReplaceMatchingOpTypesOptions {
  // When false, a replacement that drops a node still needed outside the
  // matched sub-graph is rejected and the original nodes are kept. When true
  // that check is skipped, so the graph may be left with dangling inputs and
  // you must fix those inconsistencies in a later pass.
  bool allow_inconsistencies;
};
|
||||||
|
|
||||||
|
// Replaces all of the matching sub-graphs with new ops. This calls into the
|
||||||
|
// given function, and expects to receive a set of new nodes to replace each
|
||||||
|
// matched sub-graph. It has some logic to protect the integrity of the
|
||||||
|
// resulting graph, for example making sure that nodes needed by other nodes
|
||||||
|
// outside the sub-graph aren't removed. These are passed in as the set of
|
||||||
|
// outputs, and nodes with the same names must be added to the new nodes
|
||||||
|
// produced by the replacement function. Many of these checks can be disabled
|
||||||
|
// by setting allow_inconsistencies to true in the options, but then it's the
|
||||||
|
// caller's responsibility to patch up any problems before passing on the graph
|
||||||
|
// to others. There's more comprehensive usage documentation in the README.
|
||||||
|
Status ReplaceMatchingOpTypes(
|
||||||
|
const GraphDef& input_graph_def, const OpTypePattern& pattern,
|
||||||
|
const std::function<Status(const NodeMatch&, const std::set<string>&,
|
||||||
|
const std::set<string>&, std::vector<NodeDef>*)>&
|
||||||
|
node_generator,
|
||||||
|
const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
// Returns a list of the unique nodes found in this match.
|
||||||
|
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
|
||||||
|
|
||||||
|
// Changes all input references to a particular node name.
|
||||||
|
Status RenameNodeInputs(const GraphDef& input_graph_def,
|
||||||
|
const std::map<string, string>& inputs_to_rename,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
// Utility function that copies all the nodes found in a match into the
|
||||||
|
// new_nodes list. This is useful in replacement functions when you decide to
|
||||||
|
// leave the original matched subgraph untouched and make no changes.
|
||||||
|
void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes);
|
||||||
|
|
||||||
|
// Holds information that's needed for transform functions.
|
||||||
|
typedef std::map<string, std::vector<string>> TransformFuncParameters;
|
||||||
|
struct TransformFuncContext {
|
||||||
|
std::vector<string> input_names;
|
||||||
|
std::vector<string> output_names;
|
||||||
|
TransformFuncParameters params;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Returns how many occurrences of the given parameter are present.
|
||||||
|
int CountParameters(const TransformFuncContext& context, const string& name);
|
||||||
|
|
||||||
|
// Gets a simple occurrence of a parameter, using a default if it isn't present.
|
||||||
|
Status GetExactlyOneParameter(const TransformFuncContext& context,
|
||||||
|
const string& name, const string& default_value,
|
||||||
|
string* result);
|
||||||
|
|
||||||
|
// This is the function API for all graph transformations, taking an input
|
||||||
|
// GraphDef and other arguments, and returning a transformed GraphDef.
|
||||||
|
typedef std::function<Status(const GraphDef&,
|
||||||
|
const TransformFuncContext& context, GraphDef*)>
|
||||||
|
TransformFunc;
|
||||||
|
|
||||||
|
// To add a new graph transform function, call the macro:
|
||||||
|
// REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
|
||||||
|
// Under the hood this adds the function to the list of known transforms, so you
|
||||||
|
// just need to link in the .cc file with your registration call to have access
|
||||||
|
// to it through the command line tool.
|
||||||
|
// The rest of the machinery below is to enable that automagical registration.
|
||||||
|
typedef std::map<string, TransformFunc> TransformRegistry;
|
||||||
|
TransformRegistry* GetTransformRegistry();
|
||||||
|
class TransformRegistrar {
|
||||||
|
public:
|
||||||
|
TransformRegistrar(const string& name, TransformFunc transform_func) {
|
||||||
|
TransformRegistry* transform_registry = GetTransformRegistry();
|
||||||
|
(*transform_registry)[name] = transform_func;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Defines a static TransformRegistrar whose identifier embeds `counter`.
#define REGISTER_GRAPH_TRANSFORM_UNIQ(counter, transform_name, transform_func) \
  static tensorflow::graph_transforms::TransformRegistrar                      \
      registrar__body__##counter##__object(transform_name, transform_func);

// Extra level of indirection so that `counter` is macro-expanded before it is
// token-pasted into the registrar's identifier.
#define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(counter, transform_name, \
                                             transform_func)          \
  REGISTER_GRAPH_TRANSFORM_UNIQ(counter, transform_name, transform_func)

// Public entry point: registers `transform_func` under `transform_name`, using
// __COUNTER__ to give each registrar object a distinct identifier.
#define REGISTER_GRAPH_TRANSFORM(transform_name, transform_func)           \
  REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, transform_name, \
                                       transform_func)
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -52,8 +52,8 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
std::map<string, const NodeDef*> node_map;
|
std::map<string, const NodeDef*> node_map;
|
||||||
|
|
||||||
MapNamesToNodes(graph_def, &node_map);
|
MapNamesToNodes(graph_def, &node_map);
|
||||||
|
|
||||||
EXPECT_EQ(1, node_map.count("a"));
|
EXPECT_EQ(1, node_map.count("a"));
|
||||||
EXPECT_EQ(1, node_map.count("b"));
|
EXPECT_EQ(1, node_map.count("b"));
|
||||||
EXPECT_EQ(1, node_map.count("add"));
|
EXPECT_EQ(1, node_map.count("add"));
|
||||||
@ -62,6 +62,52 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
EXPECT_EQ(0, node_map.count("no_such_node"));
|
EXPECT_EQ(0, node_map.count("no_such_node"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestMapNodesToOutputs() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 100;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, b_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
std::map<string, std::vector<const NodeDef*>> outputs_map;
|
||||||
|
MapNodesToOutputs(graph_def, &outputs_map);
|
||||||
|
|
||||||
|
EXPECT_EQ(1, outputs_map.count("a"));
|
||||||
|
EXPECT_EQ(1, outputs_map["a"].size());
|
||||||
|
EXPECT_EQ("add", outputs_map["a"][0]->name());
|
||||||
|
|
||||||
|
EXPECT_EQ(1, outputs_map.count("b"));
|
||||||
|
EXPECT_EQ(1, outputs_map["b"].size());
|
||||||
|
EXPECT_EQ("add", outputs_map["b"][0]->name());
|
||||||
|
|
||||||
|
EXPECT_EQ(1, outputs_map.count("add"));
|
||||||
|
EXPECT_EQ(1, outputs_map["add"].size());
|
||||||
|
EXPECT_EQ("output", outputs_map["add"][0]->name());
|
||||||
|
|
||||||
|
EXPECT_EQ(1, outputs_map.count("placeholder"));
|
||||||
|
EXPECT_EQ(1, outputs_map["placeholder"].size());
|
||||||
|
EXPECT_EQ("output", outputs_map["placeholder"][0]->name());
|
||||||
|
|
||||||
|
EXPECT_EQ(0, outputs_map.count("output"));
|
||||||
|
EXPECT_EQ(0, outputs_map.count("no_such_node"));
|
||||||
|
}
|
||||||
|
|
||||||
void TestNodeNamePartsFromInput() {
|
void TestNodeNamePartsFromInput() {
|
||||||
string prefix;
|
string prefix;
|
||||||
string node_name;
|
string node_name;
|
||||||
@ -101,6 +147,75 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42"));
|
EXPECT_EQ("node_name", NodeNameFromInput("^node_name:42"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestCanonicalInputName() {
|
||||||
|
EXPECT_EQ("node_name:0", CanonicalInputName("node_name"));
|
||||||
|
EXPECT_EQ("node_name:0", CanonicalInputName("node_name:0"));
|
||||||
|
EXPECT_EQ("^node_name:0", CanonicalInputName("^node_name"));
|
||||||
|
EXPECT_EQ("^node_name:42", CanonicalInputName("^node_name:42"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAddNodeInput() {
|
||||||
|
NodeDef node;
|
||||||
|
AddNodeInput("foo", &node);
|
||||||
|
EXPECT_EQ("foo", node.input(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCopyNodeAttr() {
|
||||||
|
NodeDef node;
|
||||||
|
auto mutable_attr = node.mutable_attr();
|
||||||
|
(*mutable_attr)["foo"].set_i(3);
|
||||||
|
|
||||||
|
NodeDef copied_node;
|
||||||
|
CopyNodeAttr(node, "foo", "bar", &copied_node);
|
||||||
|
EXPECT_EQ(3, copied_node.attr().at("bar").i());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSetNodeAttr() {
|
||||||
|
NodeDef node;
|
||||||
|
int32 value_i = 32;
|
||||||
|
SetNodeAttr("foo", value_i, &node);
|
||||||
|
EXPECT_EQ(32, node.attr().at("foo").i());
|
||||||
|
string value_s = "some_value";
|
||||||
|
SetNodeAttr("bar", value_s, &node);
|
||||||
|
EXPECT_EQ("some_value", node.attr().at("bar").s());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSetNodeTensorAttr() {
|
||||||
|
NodeDef node;
|
||||||
|
SetNodeTensorAttr<int32>("foo", {3, 1}, {1, 2, 3}, &node);
|
||||||
|
TensorProto tensor_proto = node.attr().at("foo").tensor();
|
||||||
|
Tensor tensor;
|
||||||
|
CHECK(tensor.FromProto(tensor_proto));
|
||||||
|
EXPECT_EQ(DT_INT32, tensor.dtype());
|
||||||
|
EXPECT_EQ(3, tensor.shape().dim_size(0));
|
||||||
|
EXPECT_EQ(1, tensor.shape().dim_size(1));
|
||||||
|
EXPECT_EQ(1, tensor.flat<int32>()(0));
|
||||||
|
EXPECT_EQ(2, tensor.flat<int32>()(1));
|
||||||
|
EXPECT_EQ(3, tensor.flat<int32>()(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSetNodeTensorAttrWithTensor() {
|
||||||
|
NodeDef node;
|
||||||
|
Tensor input_tensor(DT_INT32, {4, 5});
|
||||||
|
test::FillIota<int32>(&input_tensor, 1);
|
||||||
|
SetNodeTensorAttr<int32>("foo", input_tensor, &node);
|
||||||
|
TensorProto tensor_proto = node.attr().at("foo").tensor();
|
||||||
|
Tensor tensor;
|
||||||
|
CHECK(tensor.FromProto(tensor_proto));
|
||||||
|
test::ExpectTensorEqual<int32>(input_tensor, tensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestGetNodeTensorAttr() {
|
||||||
|
NodeDef node;
|
||||||
|
Tensor input_tensor(DT_INT32, {4, 5});
|
||||||
|
test::FillIota<int32>(&input_tensor, 1);
|
||||||
|
TensorProto tensor_proto;
|
||||||
|
input_tensor.AsProtoTensorContent(&tensor_proto);
|
||||||
|
SetNodeAttr("foo", tensor_proto, &node);
|
||||||
|
Tensor result = GetNodeTensorAttr(node, "foo");
|
||||||
|
test::ExpectTensorEqual<int32>(input_tensor, result);
|
||||||
|
}
|
||||||
|
|
||||||
void TestFilterGraphDef() {
|
void TestFilterGraphDef() {
|
||||||
auto root = tensorflow::Scope::NewRootScope();
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
@ -160,19 +275,679 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
EXPECT_EQ(nullptr,
|
EXPECT_EQ(nullptr,
|
||||||
tensorflow::AttrSlice(*removed_placeholder).Find("dtype"));
|
tensorflow::AttrSlice(*removed_placeholder).Find("dtype"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestGetOpTypeMatches() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 100;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, b_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphMatcher matcher(graph_def);
|
||||||
|
|
||||||
|
std::vector<NodeMatch> const_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Const"}, &const_matches));
|
||||||
|
EXPECT_EQ(2, const_matches.size());
|
||||||
|
for (const NodeMatch& match : const_matches) {
|
||||||
|
EXPECT_EQ("Const", match.node.op());
|
||||||
|
EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
|
||||||
|
<< "match.node.name()=" << match.node.name();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<NodeMatch> add_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add"}, &add_matches));
|
||||||
|
EXPECT_EQ(1, add_matches.size());
|
||||||
|
EXPECT_EQ("Add", add_matches[0].node.op());
|
||||||
|
EXPECT_EQ("add", add_matches[0].node.name());
|
||||||
|
|
||||||
|
std::vector<NodeMatch> add_child_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
|
||||||
|
&add_child_matches));
|
||||||
|
EXPECT_EQ(1, add_child_matches.size());
|
||||||
|
EXPECT_EQ("Add", add_child_matches[0].node.op());
|
||||||
|
EXPECT_EQ("add", add_child_matches[0].node.name());
|
||||||
|
EXPECT_EQ(2, add_child_matches[0].inputs.size());
|
||||||
|
for (const NodeMatch& match : add_child_matches[0].inputs) {
|
||||||
|
EXPECT_EQ("Const", match.node.op());
|
||||||
|
EXPECT_TRUE(("a" == match.node.name()) || ("b" == match.node.name()))
|
||||||
|
<< "match.node.name()=" << match.node.name();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<NodeMatch> no_such_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches({"NoSuch"}, &no_such_matches));
|
||||||
|
EXPECT_EQ(0, no_such_matches.size());
|
||||||
|
|
||||||
|
std::vector<NodeMatch> all_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches(
|
||||||
|
{"Mul", {{"Add", {{"Const"}, {"Const"}}}, {"Placeholder"}}},
|
||||||
|
&all_matches));
|
||||||
|
EXPECT_EQ(1, all_matches.size());
|
||||||
|
EXPECT_EQ("Mul", all_matches[0].node.op());
|
||||||
|
EXPECT_EQ("output", all_matches[0].node.name());
|
||||||
|
EXPECT_EQ(2, all_matches[0].inputs.size());
|
||||||
|
EXPECT_EQ("Add", all_matches[0].inputs[0].node.op());
|
||||||
|
EXPECT_EQ("add", all_matches[0].inputs[0].node.name());
|
||||||
|
EXPECT_EQ(2, all_matches[0].inputs[0].inputs.size());
|
||||||
|
EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[0].node.op());
|
||||||
|
EXPECT_EQ("a", all_matches[0].inputs[0].inputs[0].node.name());
|
||||||
|
EXPECT_EQ(0, all_matches[0].inputs[0].inputs[0].inputs.size());
|
||||||
|
EXPECT_EQ("Const", all_matches[0].inputs[0].inputs[1].node.op());
|
||||||
|
EXPECT_EQ("b", all_matches[0].inputs[0].inputs[1].node.name());
|
||||||
|
EXPECT_EQ(0, all_matches[0].inputs[0].inputs[1].inputs.size());
|
||||||
|
EXPECT_EQ("Placeholder", all_matches[0].inputs[1].node.op());
|
||||||
|
EXPECT_EQ("placeholder", all_matches[0].inputs[1].node.name());
|
||||||
|
EXPECT_EQ(0, all_matches[0].inputs[1].inputs.size());
|
||||||
|
|
||||||
|
std::vector<NodeMatch> wildcard_matches;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
matcher.GetOpTypeMatches({"*", {{"*"}, {"*"}}}, &wildcard_matches));
|
||||||
|
EXPECT_EQ(1, wildcard_matches.size());
|
||||||
|
EXPECT_EQ("Add", wildcard_matches[0].node.op());
|
||||||
|
EXPECT_EQ("Const", wildcard_matches[0].inputs[0].node.op());
|
||||||
|
EXPECT_EQ("a", wildcard_matches[0].inputs[0].node.name());
|
||||||
|
EXPECT_EQ("Const", wildcard_matches[0].inputs[1].node.op());
|
||||||
|
EXPECT_EQ("b", wildcard_matches[0].inputs[1].node.name());
|
||||||
|
|
||||||
|
std::vector<NodeMatch> or_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add|Mul"}, &or_matches));
|
||||||
|
EXPECT_EQ(2, or_matches.size());
|
||||||
|
EXPECT_EQ("Add", or_matches[0].node.op());
|
||||||
|
EXPECT_EQ("add", or_matches[0].node.name());
|
||||||
|
EXPECT_EQ("Mul", or_matches[1].node.op());
|
||||||
|
EXPECT_EQ("output", or_matches[1].node.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestGetOpTypeMatchesDAG() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 100;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, a_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphMatcher matcher(graph_def);
|
||||||
|
|
||||||
|
std::vector<NodeMatch> add_matches;
|
||||||
|
TF_ASSERT_OK(matcher.GetOpTypeMatches({"Add", {{"Const"}, {"Const"}}},
|
||||||
|
&add_matches));
|
||||||
|
EXPECT_EQ(1, add_matches.size());
|
||||||
|
EXPECT_EQ("Add", add_matches[0].node.op());
|
||||||
|
EXPECT_EQ("add", add_matches[0].node.name());
|
||||||
|
EXPECT_EQ("Const", add_matches[0].inputs[0].node.op());
|
||||||
|
EXPECT_EQ("a", add_matches[0].inputs[0].node.name());
|
||||||
|
EXPECT_EQ("Const", add_matches[0].inputs[1].node.op());
|
||||||
|
EXPECT_EQ("a", add_matches[0].inputs[1].node.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestReplaceMatchingOpTypes() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, b_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphDef replaced_graph_def;
|
||||||
|
TF_ASSERT_OK(ReplaceMatchingOpTypes(
|
||||||
|
graph_def, {"*"},
|
||||||
|
[](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
const std::set<string>& output_nodes,
|
||||||
|
std::vector<NodeDef>* new_nodes) {
|
||||||
|
NodeDef original_copy;
|
||||||
|
original_copy.CopyFrom(match.node);
|
||||||
|
const string original_name = match.node.name();
|
||||||
|
original_copy.set_name(original_name + "_before_identity");
|
||||||
|
new_nodes->push_back(original_copy);
|
||||||
|
|
||||||
|
NodeDef identity_node;
|
||||||
|
identity_node.set_op("Identity");
|
||||||
|
identity_node.set_name(original_name);
|
||||||
|
*(identity_node.mutable_input()->Add()) = original_copy.name();
|
||||||
|
new_nodes->push_back(identity_node);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
},
|
||||||
|
{}, &replaced_graph_def));
|
||||||
|
|
||||||
|
EXPECT_EQ(10, replaced_graph_def.node_size());
|
||||||
|
for (const NodeDef& node : replaced_graph_def.node()) {
|
||||||
|
if (node.name() == "output") {
|
||||||
|
EXPECT_EQ("Identity", node.op());
|
||||||
|
EXPECT_EQ("output_before_identity", node.input(0));
|
||||||
|
} else if (node.name() == "output_before_identity") {
|
||||||
|
EXPECT_EQ("Mul", node.op());
|
||||||
|
EXPECT_EQ("add", node.input(0));
|
||||||
|
EXPECT_EQ("placeholder", node.input(1));
|
||||||
|
} else if (node.name() == "placeholder") {
|
||||||
|
EXPECT_EQ("Identity", node.op());
|
||||||
|
EXPECT_EQ("placeholder_before_identity", node.input(0));
|
||||||
|
} else if (node.name() == "placeholder_before_identity") {
|
||||||
|
EXPECT_EQ("Placeholder", node.op());
|
||||||
|
} else if (node.name() == "add") {
|
||||||
|
EXPECT_EQ("Identity", node.op());
|
||||||
|
EXPECT_EQ("add_before_identity", node.input(0));
|
||||||
|
} else if (node.name() == "add_before_identity") {
|
||||||
|
EXPECT_EQ("Add", node.op());
|
||||||
|
EXPECT_EQ("a", node.input(0));
|
||||||
|
EXPECT_EQ("b", node.input(1));
|
||||||
|
} else if (node.name() == "a") {
|
||||||
|
EXPECT_EQ("Identity", node.op());
|
||||||
|
EXPECT_EQ("a_before_identity", node.input(0));
|
||||||
|
} else if (node.name() == "a_before_identity") {
|
||||||
|
EXPECT_EQ("Const", node.op());
|
||||||
|
} else if (node.name() == "b") {
|
||||||
|
EXPECT_EQ("Identity", node.op());
|
||||||
|
EXPECT_EQ("b_before_identity", node.input(0));
|
||||||
|
} else if (node.name() == "b_before_identity") {
|
||||||
|
EXPECT_EQ("Const", node.op());
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(true, false) << "Unexpected node name found: " << node.name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestMatchedNodesAsArray() {
|
||||||
|
NodeMatch fourth;
|
||||||
|
fourth.node.set_name("fourth");
|
||||||
|
|
||||||
|
NodeMatch second;
|
||||||
|
second.node.set_name("second");
|
||||||
|
second.inputs.push_back(fourth);
|
||||||
|
|
||||||
|
NodeMatch third;
|
||||||
|
third.node.set_name("third");
|
||||||
|
third.inputs.push_back(fourth);
|
||||||
|
|
||||||
|
NodeMatch first;
|
||||||
|
first.node.set_name("first");
|
||||||
|
first.inputs.push_back(second);
|
||||||
|
first.inputs.push_back(third);
|
||||||
|
|
||||||
|
std::vector<NodeDef> result;
|
||||||
|
MatchedNodesAsArray(first, &result);
|
||||||
|
|
||||||
|
EXPECT_EQ(4, result.size());
|
||||||
|
EXPECT_EQ("first", result[0].name());
|
||||||
|
EXPECT_EQ("second", result[1].name());
|
||||||
|
EXPECT_EQ("third", result[2].name());
|
||||||
|
EXPECT_EQ("fourth", result[3].name());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRenameNodeInputs() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, a_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphDef renamed_graph_def;
|
||||||
|
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"a", "b"}}, &renamed_graph_def));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_map;
|
||||||
|
MapNamesToNodes(renamed_graph_def, &node_map);
|
||||||
|
EXPECT_EQ("b", node_map.at("add")->input(0));
|
||||||
|
EXPECT_EQ("b", node_map.at("add")->input(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRenameNodeInputsWithRedirects() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Tensor c_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&c_data, 1.0f);
|
||||||
|
Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, b_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphDef renamed_graph_def;
|
||||||
|
TF_ASSERT_OK(RenameNodeInputs(
|
||||||
|
graph_def, {{"a", "f"}, {"f", "e"}, {"e", "d"}, {"d", "c"}},
|
||||||
|
&renamed_graph_def));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_map;
|
||||||
|
MapNamesToNodes(renamed_graph_def, &node_map);
|
||||||
|
EXPECT_EQ("c", node_map.at("add")->input(0));
|
||||||
|
EXPECT_EQ("b", node_map.at("add")->input(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRenameNodeInputsWithCycle() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
Tensor c_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&c_data, 1.0f);
|
||||||
|
Output c_const = Const(root.WithOpName("c"), Input::Initializer(c_data));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), a_const, b_const);
|
||||||
|
|
||||||
|
Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||||
|
|
||||||
|
Output mul = Mul(root.WithOpName("output"), add, placeholder);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphDef renamed_graph_def;
|
||||||
|
Status rename_status = RenameNodeInputs(graph_def, {{"a", "d"}, {"d", "a"}},
|
||||||
|
&renamed_graph_def);
|
||||||
|
EXPECT_FALSE(rename_status.ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRenameNodeInputsWithWildcard() {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
|
||||||
|
QuantizeV2 quantize_a(root.WithOpName("quantize_a"), a_const, a_const,
|
||||||
|
a_const, DT_QUINT8,
|
||||||
|
QuantizeV2::Attrs().Mode("MIN_FIRST"));
|
||||||
|
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
|
||||||
|
|
||||||
|
QuantizeV2 quantize_b(root.WithOpName("quantize_b"), b_const, b_const,
|
||||||
|
b_const, DT_QUINT8,
|
||||||
|
QuantizeV2::Attrs().Mode("MIN_FIRST"));
|
||||||
|
|
||||||
|
Output add = Add(root.WithOpName("add"), quantize_a.output_min,
|
||||||
|
quantize_a.output_max);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
GraphDef renamed_graph_def;
|
||||||
|
TF_ASSERT_OK(RenameNodeInputs(graph_def, {{"quantize_a:*", "quantize_b"}},
|
||||||
|
&renamed_graph_def));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_map;
|
||||||
|
MapNamesToNodes(renamed_graph_def, &node_map);
|
||||||
|
EXPECT_EQ("quantize_b:1", node_map.at("add")->input(0));
|
||||||
|
EXPECT_EQ("quantize_b:2", node_map.at("add")->input(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestFindInvalidInputs() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node = graph_def.mutable_node()->Add();
|
||||||
|
mul_node->set_op("Mul");
|
||||||
|
mul_node->set_name("mul_node");
|
||||||
|
*(mul_node->mutable_input()->Add()) = "add_node1";
|
||||||
|
*(mul_node->mutable_input()->Add()) = "add_node2:0";
|
||||||
|
*(mul_node->mutable_input()->Add()) = "^const_node1:0";
|
||||||
|
|
||||||
|
NodeDef* add_node1 = graph_def.mutable_node()->Add();
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
*(add_node1->mutable_input()->Add()) = "missing_input1";
|
||||||
|
*(add_node1->mutable_input()->Add()) = "const_node1:0";
|
||||||
|
*(add_node1->mutable_input()->Add()) = "missing_input2";
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.mutable_node()->Add();
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
*(add_node2->mutable_input()->Add()) = "missing_input3";
|
||||||
|
*(add_node2->mutable_input()->Add()) = "const_node1:0";
|
||||||
|
*(add_node2->mutable_input()->Add()) = "^const_node2";
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.mutable_node()->Add();
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.mutable_node()->Add();
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
|
||||||
|
std::vector<std::pair<string, string>> invalid_inputs;
|
||||||
|
FindInvalidInputs(graph_def, &invalid_inputs);
|
||||||
|
EXPECT_EQ(3, invalid_inputs.size());
|
||||||
|
for (const std::pair<string, string>& invalid_input : invalid_inputs) {
|
||||||
|
EXPECT_TRUE((invalid_input.first == "add_node1") ||
|
||||||
|
(invalid_input.first == "add_node2"));
|
||||||
|
if (invalid_input.first == "add_node1") {
|
||||||
|
EXPECT_TRUE((invalid_input.second == "missing_input1") ||
|
||||||
|
(invalid_input.second == "missing_input2"))
|
||||||
|
<< invalid_input.second;
|
||||||
|
} else if (invalid_input.first == "add_node2") {
|
||||||
|
EXPECT_EQ("missing_input3", invalid_input.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestIsGraphValid() {
|
||||||
|
GraphDef invalid_graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node = invalid_graph_def.mutable_node()->Add();
|
||||||
|
mul_node->set_op("Mul");
|
||||||
|
mul_node->set_name("mul_node");
|
||||||
|
*(mul_node->mutable_input()->Add()) = "add_node1";
|
||||||
|
*(mul_node->mutable_input()->Add()) = "add_node2:0";
|
||||||
|
*(mul_node->mutable_input()->Add()) = "^const_node1:0";
|
||||||
|
|
||||||
|
NodeDef* add_node1 = invalid_graph_def.mutable_node()->Add();
|
||||||
|
add_node1->set_op("Add");
|
||||||
|
add_node1->set_name("add_node1");
|
||||||
|
*(add_node1->mutable_input()->Add()) = "missing_input1";
|
||||||
|
*(add_node1->mutable_input()->Add()) = "const_node1:0";
|
||||||
|
*(add_node1->mutable_input()->Add()) = "missing_input2";
|
||||||
|
|
||||||
|
NodeDef* add_node2 = invalid_graph_def.mutable_node()->Add();
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
*(add_node2->mutable_input()->Add()) = "missing_input3";
|
||||||
|
*(add_node2->mutable_input()->Add()) = "const_node1:0";
|
||||||
|
*(add_node2->mutable_input()->Add()) = "^const_node2";
|
||||||
|
|
||||||
|
NodeDef* const_node1 = invalid_graph_def.mutable_node()->Add();
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = invalid_graph_def.mutable_node()->Add();
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
|
||||||
|
EXPECT_FALSE(IsGraphValid(invalid_graph_def).ok());
|
||||||
|
|
||||||
|
GraphDef valid_graph_def;
|
||||||
|
|
||||||
|
NodeDef* const_node3 = valid_graph_def.mutable_node()->Add();
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
const_node3->set_name("const_node2");
|
||||||
|
|
||||||
|
EXPECT_TRUE(IsGraphValid(valid_graph_def).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCopyOriginalMatch() {
|
||||||
|
NodeDef a;
|
||||||
|
a.set_op("Relu");
|
||||||
|
a.set_name("a");
|
||||||
|
AddNodeInput("b", &a);
|
||||||
|
|
||||||
|
NodeDef b;
|
||||||
|
b.set_op("Const");
|
||||||
|
b.set_name("b");
|
||||||
|
|
||||||
|
NodeMatch b_match;
|
||||||
|
b_match.node = b;
|
||||||
|
|
||||||
|
NodeMatch a_match;
|
||||||
|
a_match.node = a;
|
||||||
|
a_match.inputs.push_back(b_match);
|
||||||
|
|
||||||
|
std::vector<NodeDef> new_nodes;
|
||||||
|
CopyOriginalMatch(a_match, &new_nodes);
|
||||||
|
EXPECT_EQ(2, new_nodes.size());
|
||||||
|
EXPECT_EQ("a", new_nodes[0].name());
|
||||||
|
EXPECT_EQ("Relu", new_nodes[0].op());
|
||||||
|
EXPECT_EQ("b", new_nodes[1].name());
|
||||||
|
EXPECT_EQ("Const", new_nodes[1].op());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestHashNodeDef() {
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
auto a_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(a_root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
GraphDef a_graph_def;
|
||||||
|
TF_ASSERT_OK(a_root.ToGraphDef(&a_graph_def));
|
||||||
|
const NodeDef& a_node_def = a_graph_def.node(0);
|
||||||
|
|
||||||
|
auto b_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor b_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&b_data, 1.0f);
|
||||||
|
Output b_const = Const(b_root.WithOpName("a"), Input::Initializer(b_data));
|
||||||
|
GraphDef b_graph_def;
|
||||||
|
TF_ASSERT_OK(b_root.ToGraphDef(&b_graph_def));
|
||||||
|
const NodeDef& b_node_def = b_graph_def.node(0);
|
||||||
|
|
||||||
|
auto c_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor c_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&c_data, 2.0f);
|
||||||
|
Output c_const = Const(c_root.WithOpName("a"), Input::Initializer(c_data));
|
||||||
|
GraphDef c_graph_def;
|
||||||
|
TF_ASSERT_OK(c_root.ToGraphDef(&c_graph_def));
|
||||||
|
const NodeDef& c_node_def = c_graph_def.node(0);
|
||||||
|
|
||||||
|
auto d_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor d_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&d_data, 1.0f);
|
||||||
|
Output d_const = Const(d_root.WithOpName("d"), Input::Initializer(d_data));
|
||||||
|
GraphDef d_graph_def;
|
||||||
|
TF_ASSERT_OK(d_root.ToGraphDef(&d_graph_def));
|
||||||
|
const NodeDef& d_node_def = d_graph_def.node(0);
|
||||||
|
|
||||||
|
auto e_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor e_data(DT_INT32, TensorShape({width}));
|
||||||
|
test::FillIota<int32>(&e_data, 1);
|
||||||
|
Output e_const = Const(e_root.WithOpName("a"), Input::Initializer(e_data));
|
||||||
|
GraphDef e_graph_def;
|
||||||
|
TF_ASSERT_OK(e_root.ToGraphDef(&e_graph_def));
|
||||||
|
const NodeDef& e_node_def = e_graph_def.node(0);
|
||||||
|
|
||||||
|
auto f_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor f_data(DT_FLOAT, TensorShape({width - 1}));
|
||||||
|
test::FillIota<float>(&f_data, 1.0f);
|
||||||
|
Output f_const = Const(f_root.WithOpName("a"), Input::Initializer(f_data));
|
||||||
|
GraphDef f_graph_def;
|
||||||
|
TF_ASSERT_OK(f_root.ToGraphDef(&f_graph_def));
|
||||||
|
const NodeDef& f_node_def = f_graph_def.node(0);
|
||||||
|
|
||||||
|
auto g_root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor g_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&g_data, 1);
|
||||||
|
Output g_const = Const(g_root.WithOpName("a").WithDevice("some_device"),
|
||||||
|
Input::Initializer(g_data));
|
||||||
|
GraphDef g_graph_def;
|
||||||
|
TF_ASSERT_OK(g_root.ToGraphDef(&g_graph_def));
|
||||||
|
const NodeDef& g_node_def = g_graph_def.node(0);
|
||||||
|
|
||||||
|
NodeDef relu1_node_def;
|
||||||
|
relu1_node_def.set_op("Relu");
|
||||||
|
relu1_node_def.set_name("a");
|
||||||
|
relu1_node_def.add_input("foo");
|
||||||
|
|
||||||
|
NodeDef relu2_node_def;
|
||||||
|
relu2_node_def.set_op("Relu");
|
||||||
|
relu2_node_def.set_name("a");
|
||||||
|
relu2_node_def.add_input("bar");
|
||||||
|
|
||||||
|
EXPECT_EQ(HashNodeDef(a_node_def), HashNodeDef(b_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(c_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(d_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(e_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(f_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(g_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(a_node_def), HashNodeDef(relu1_node_def));
|
||||||
|
EXPECT_NE(HashNodeDef(relu1_node_def), HashNodeDef(relu2_node_def));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestCountParameters() {
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.params.insert({"foo", {"a", "b"}});
|
||||||
|
context.params.insert({"bar", {"c"}});
|
||||||
|
EXPECT_EQ(2, CountParameters(context, "foo"));
|
||||||
|
EXPECT_EQ(1, CountParameters(context, "bar"));
|
||||||
|
EXPECT_EQ(0, CountParameters(context, "not_present"));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestGetExactlyOneParameter() {
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.params.insert({"foo", {"a", "b"}});
|
||||||
|
context.params.insert({"bar", {"c"}});
|
||||||
|
string value;
|
||||||
|
TF_EXPECT_OK(GetExactlyOneParameter(context, "bar", "d", &value));
|
||||||
|
EXPECT_EQ("c", value);
|
||||||
|
EXPECT_FALSE(GetExactlyOneParameter(context, "foo", "d", &value).ok());
|
||||||
|
TF_EXPECT_OK(GetExactlyOneParameter(context, "not_present", "d", &value));
|
||||||
|
EXPECT_EQ("d", value);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Registers the fixture's TestMapNamesToNodes() check as a gtest case.
// Defined once only: a second TEST_F with the same fixture and test name
// would redefine the same generated test class.
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
||||||
|
|
||||||
|
// Registers the fixture's TestMapNodesToOutputs() check as a gtest case.
TEST_F(TransformUtilsTest, TestMapNodesToOutputs) {
  TestMapNodesToOutputs();
}
|
||||||
|
|
||||||
// Registers the fixture's TestNodeNamePartsFromInput() check as a gtest case.
// Defined once only: a second TEST_F with the same fixture and test name
// would redefine the same generated test class.
TEST_F(TransformUtilsTest, TestNodeNamePartsFromInput) {
  TestNodeNamePartsFromInput();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestCanonicalInputName() check as a gtest case.
TEST_F(TransformUtilsTest, TestCanonicalInputName) {
  TestCanonicalInputName();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestAddNodeInput() check as a gtest case.
TEST_F(TransformUtilsTest, TestAddNodeInput) {
  TestAddNodeInput();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestCopyNodeAttr() check as a gtest case.
TEST_F(TransformUtilsTest, TestCopyNodeAttr) {
  TestCopyNodeAttr();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestSetNodeAttr() check as a gtest case.
TEST_F(TransformUtilsTest, TestSetNodeAttr) {
  TestSetNodeAttr();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestSetNodeTensorAttr() check as a gtest case.
TEST_F(TransformUtilsTest, TestSetNodeTensorAttr) {
  TestSetNodeTensorAttr();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestSetNodeTensorAttrWithTensor() check as a gtest
// case.
TEST_F(TransformUtilsTest, TestSetNodeTensorAttrWithTensor) {
  TestSetNodeTensorAttrWithTensor();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestGetNodeTensorAttr() check as a gtest case.
TEST_F(TransformUtilsTest, TestGetNodeTensorAttr) {
  TestGetNodeTensorAttr();
}
|
||||||
|
|
||||||
// Registers the fixture's TestNodeNameFromInput() check as a gtest case.
// Defined once only: a second TEST_F with the same fixture and test name
// would redefine the same generated test class.
TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); }
|
||||||
|
|
||||||
// Registers the fixture's TestFilterGraphDef() check as a gtest case.
// Defined once only: a second TEST_F with the same fixture and test name
// would redefine the same generated test class.
TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); }
|
||||||
|
|
||||||
// Registers the fixture's TestRemoveAttributes() check as a gtest case.
// Defined once only: a second TEST_F with the same fixture and test name
// would redefine the same generated test class.
TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); }
|
||||||
|
|
||||||
|
// Registers the fixture's TestGetOpTypeMatches() check as a gtest case.
TEST_F(TransformUtilsTest, TestGetOpTypeMatches) {
  TestGetOpTypeMatches();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestGetOpTypeMatchesDAG() check as a gtest case.
TEST_F(TransformUtilsTest, TestGetOpTypeMatchesDAG) {
  TestGetOpTypeMatchesDAG();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestReplaceMatchingOpTypes() check as a gtest case.
TEST_F(TransformUtilsTest, TestReplaceMatchingOpTypes) {
  TestReplaceMatchingOpTypes();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestMatchedNodesAsArray() check as a gtest case.
TEST_F(TransformUtilsTest, TestMatchedNodesAsArray) {
  TestMatchedNodesAsArray();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestRenameNodeInputs() check as a gtest case.
TEST_F(TransformUtilsTest, TestRenameNodeInputs) {
  TestRenameNodeInputs();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestRenameNodeInputsWithRedirects() check as a
// gtest case.
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithRedirects) {
  TestRenameNodeInputsWithRedirects();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestRenameNodeInputsWithCycle() check as a gtest
// case.
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithCycle) {
  TestRenameNodeInputsWithCycle();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestRenameNodeInputsWithWildcard() check as a gtest
// case.
TEST_F(TransformUtilsTest, TestRenameNodeInputsWithWildcard) {
  TestRenameNodeInputsWithWildcard();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestFindInvalidInputs() check as a gtest case.
TEST_F(TransformUtilsTest, TestFindInvalidInputs) {
  TestFindInvalidInputs();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestIsGraphValid() check as a gtest case.
TEST_F(TransformUtilsTest, TestIsGraphValid) {
  TestIsGraphValid();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestCopyOriginalMatch() check as a gtest case.
TEST_F(TransformUtilsTest, TestCopyOriginalMatch) {
  TestCopyOriginalMatch();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestHashNodeDef() check as a gtest case.
TEST_F(TransformUtilsTest, TestHashNodeDef) {
  TestHashNodeDef();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestCountParameters() check as a gtest case.
TEST_F(TransformUtilsTest, TestCountParameters) {
  TestCountParameters();
}
|
||||||
|
|
||||||
|
// Registers the fixture's TestGetExactlyOneParameter() check as a gtest case.
TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) {
  TestGetExactlyOneParameter();
}
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user