Add new transforms and compare_graph tool

Change: 144024968
This commit is contained in:
Pete Warden 2017-01-09 16:46:48 -08:00 committed by TensorFlower Gardener
parent bfa0fa079d
commit 80679a6a74
14 changed files with 686 additions and 72 deletions

View File

@ -43,6 +43,7 @@ tf_cc_test(
":transform_utils", ":transform_utils",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -55,6 +56,7 @@ tf_cc_test(
cc_library( cc_library(
name = "transforms_lib", name = "transforms_lib",
srcs = [ srcs = [
"add_default_attributes.cc",
"fold_batch_norms.cc", "fold_batch_norms.cc",
"fold_constants_lib.cc", "fold_constants_lib.cc",
"fold_old_batch_norms.cc", "fold_old_batch_norms.cc",
@ -68,6 +70,7 @@ cc_library(
"rename_attribute.cc", "rename_attribute.cc",
"rename_op.cc", "rename_op.cc",
"round_weights.cc", "round_weights.cc",
"set_device.cc",
"sort_by_execution_order.cc", "sort_by_execution_order.cc",
"strip_unused_nodes.cc", "strip_unused_nodes.cc",
], ],
@ -80,6 +83,7 @@ cc_library(
":transform_utils", ":transform_utils",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
@ -93,6 +97,7 @@ tf_cc_test(
name = "transforms_test", name = "transforms_test",
size = "small", size = "small",
srcs = [ srcs = [
"add_default_attributes_test.cc",
"fold_batch_norms_test.cc", "fold_batch_norms_test.cc",
"fold_constants_test.cc", "fold_constants_test.cc",
"fold_old_batch_norms_test.cc", "fold_old_batch_norms_test.cc",
@ -106,6 +111,7 @@ tf_cc_test(
"rename_attribute_test.cc", "rename_attribute_test.cc",
"rename_op_test.cc", "rename_op_test.cc",
"round_weights_test.cc", "round_weights_test.cc",
"set_device_test.cc",
"sort_by_execution_order_test.cc", "sort_by_execution_order_test.cc",
"strip_unused_nodes_test.cc", "strip_unused_nodes_test.cc",
], ],
@ -209,3 +215,18 @@ cc_binary(
":summarize_graph_main_lib", ":summarize_graph_main_lib",
], ],
) )
cc_binary(
name = "compare_graphs",
srcs = ["compare_graphs.cc"],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":transform_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)

View File

@ -12,6 +12,7 @@
* [Shrinking File Size](#shrinking-file-size) * [Shrinking File Size](#shrinking-file-size)
* [Eight-bit Calculations](#eight-bit-calculations) * [Eight-bit Calculations](#eight-bit-calculations)
* [Transform Reference](#transform-reference) * [Transform Reference](#transform-reference)
* [add_default_attributes](#add_default_attributes)
* [fold_batch_norms](#fold_batch_norms) * [fold_batch_norms](#fold_batch_norms)
* [fold_constants](#fold_constants) * [fold_constants](#fold_constants)
* [fold_old_batch_norms](#fold_old_batch_norms) * [fold_old_batch_norms](#fold_old_batch_norms)
@ -26,6 +27,7 @@
* [rename_attribute](#rename_attribute) * [rename_attribute](#rename_attribute)
* [rename_op](#rename_op) * [rename_op](#rename_op)
* [round_weights](#round_weights) * [round_weights](#round_weights)
* [set_device](#set_device)
* [sort_by_execution_order](#sort_by_execution_order) * [sort_by_execution_order](#sort_by_execution_order)
* [strip_unused_nodes](#strip_unused_nodes) * [strip_unused_nodes](#strip_unused_nodes)
* [Writing Your Own Transforms](#writing-your-own-transforms) * [Writing Your Own Transforms](#writing-your-own-transforms)
@ -318,6 +320,18 @@ logged and the transform skipped. This is especially useful for optional
transforms where version errors or other unimportant problems may trigger an transforms where version errors or other unimportant problems may trigger an
error. error.
### add_default_attributes
Args: None
When attributes are added to ops in new versions of TensorFlow, they often have
defaults to ensure backwards compatible behavior with their original versions.
These defaults usually get added when the graph is loaded by the runtime, but if
your model is going to be processed outside of the main TensorFlow framework it
can be useful to run this update process as a transform. This process finds any
op attributes that are defined in the current TensorFlow list of ops but not
within the saved model, and sets them to the defined default for that attribute.
### fold_batch_norms ### fold_batch_norms
Args: None Args: None
@ -516,6 +530,22 @@ between the largest and smallest values present. This is useful when you'll be
deploying on mobile, and you want a model that will compress effectively. See deploying on mobile, and you want a model that will compress effectively. See
[shrinking file size](#shrinking-file-size) for more details. [shrinking file size](#shrinking-file-size) for more details.
### set_device
Args:
* device: What device to assign to ops.
* if_default: If this is true, only assign to ops with empty existing devices.
Updates nodes to use the specified device. A device is a way to tell the code
that executes the graph which piece of hardware it should run particular nodes
on. The right assignment to use may change between training and deployment, so
this transform (and [remove_device](#remove_device)) provides a way of updating
the placement. If the `if_default` parameter is set, then only ops that don't
have a device assigned already will be updated. This is mostly useful for
preprocessing of graphs for other stages that expect all ops to have an explicit
device assigned.
### sort_by_execution_order ### sort_by_execution_order
Args: None\ Args: None\
@ -844,23 +874,13 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
parameter: parameter:
```C++ ```C++
string num_steps_string; TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
TF_RETURN_IF_ERROR(
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
int32 num_steps;
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
return errors::InvalidArgument(
"Couldn't interpret the num_steps argument to round_weights as a "
"number:",
num_steps_string);
}
``` ```
Something to notice here is that you have to convert the string to an integer, If the conversion fails or the parameter occurs more than once the helper
and if the conversion fails you need to raise a meaningful error through the function will raise a meaningful error through the status result of the
status result of the transform. Also, we're using a helper function which raises transform. If the parameter isn't specified at all then the default will be
an error if the parameter is present multiple times, and uses a default if the used.
user hasn't specified it.
### Function Libraries ### Function Libraries

View File

@ -0,0 +1,40 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Sets any attributes not specified in a node to their registered defaults.
//
// Takes the input graph, copies it unchanged into `output_graph_def`, and
// then fills in default values for any op attributes that are defined in the
// current TensorFlow op registry (or the graph's own function library) but
// missing from the saved nodes. Returns an error if a node references an op
// that isn't known to the registry. `context` is unused; the transform takes
// no parameters.
Status AddDefaultAttributes(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  // Find all of the ops that are currently defined, including any functions
  // carried in the graph's own library. A stack object is sufficient here;
  // no heap allocation is needed.
  FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                     input_graph_def.library());
  // AddDefaultAttrsToGraphDef works in-place, so copy over the original graph
  // first. The trailing 0 is the node offset, so every node is processed.
  *output_graph_def = input_graph_def;
  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(output_graph_def, flib_def, 0));
  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("add_default_attributes", AddDefaultAttributes);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,74 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status AddDefaultAttributes(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def);

class AddDefaultAttributesTest : public ::testing::Test {
 protected:
  // Builds a graph holding one LRN node with no attributes and one with all
  // four attributes set to non-default values, runs the transform, and checks
  // that the bare node picks up the defaults while the configured node keeps
  // its explicit values.
  void TestAddDefaultAttributes() {
    GraphDef graph_def;

    // LRN op with no attributes; the transform should fill in all defaults.
    NodeDef* bare_lrn = graph_def.add_node();
    bare_lrn->set_name("lrn_node1");
    bare_lrn->set_op("LRN");

    // LRN op with every attribute set explicitly; these must survive intact.
    NodeDef* configured_lrn = graph_def.add_node();
    configured_lrn->set_name("lrn_node2");
    configured_lrn->set_op("LRN");
    SetNodeAttr("depth_radius", 7, configured_lrn);
    SetNodeAttr("bias", 2.0f, configured_lrn);
    SetNodeAttr("alpha", 2.0f, configured_lrn);
    SetNodeAttr("beta", 1.0f, configured_lrn);

    GraphDef result;
    TF_ASSERT_OK(AddDefaultAttributes(graph_def, {}, &result));

    std::map<string, const NodeDef*> nodes;
    MapNamesToNodes(result, &nodes);

    const NodeDef* defaulted = nodes.at("lrn_node1");
    EXPECT_EQ(5, defaulted->attr().at("depth_radius").i());
    EXPECT_NEAR(1.0f, defaulted->attr().at("bias").f(), 1e-5f);
    EXPECT_NEAR(1.0f, defaulted->attr().at("alpha").f(), 1e-5f);
    EXPECT_NEAR(0.5f, defaulted->attr().at("beta").f(), 1e-5f);

    const NodeDef* preserved = nodes.at("lrn_node2");
    EXPECT_EQ(7, preserved->attr().at("depth_radius").i());
    EXPECT_NEAR(2.0f, preserved->attr().at("bias").f(), 1e-5f);
    EXPECT_NEAR(2.0f, preserved->attr().at("alpha").f(), 1e-5f);
    EXPECT_NEAR(1.0f, preserved->attr().at("beta").f(), 1e-5f);
  }
};

TEST_F(AddDefaultAttributesTest, TestAddDefaultAttributes) {
  TestAddDefaultAttributes();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,79 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Compares two TensorFlow graphs to see if their meaning is the same. This is a
// semantic comparison that's intended to show whether the graphs should produce
// the same results, and so ignores details like version numbers or node
// ordering that don't affect the output. To use it, run something like this:
//
// bazel build tensorflow/tools/graph_transforms:compare_graphs
// bazel-bin/tensorflow/tools/graph_transforms/compare_graphs a.pb b.pb
//
// The return value is 0 if the graphs are equal, 1 if they're different, and -1
// if there was a problem.
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
namespace {
int ParseFlagsAndCompareGraphs(int argc, char* argv[]) {
// We need to call this to set up global state for TensorFlow.
port::InitMain(argv[0], &argc, &argv);
if (argc != 3) {
LOG(ERROR) << "compare_graphs expects two file names as arguments";
return -1;
}
GraphDef a;
Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a);
if (!a_load_status.ok()) {
LOG(ERROR) << "Loading graph '" << argv[1] << "' failed with "
<< a_load_status.error_message();
return -1;
}
GraphDef b;
Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b);
if (!b_load_status.ok()) {
LOG(ERROR) << "Loading graph '" << argv[2] << "' failed with "
<< b_load_status.error_message();
return -1;
}
string diff;
if (EqualGraphDef(a, b, &diff)) {
std::cout << "Graphs are equal." << std::endl;
return 0;
} else {
std::cout << diff << std::endl;
return 1;
}
}
} // namespace
} // namespace graph_transforms
} // namespace tensorflow
// Entry point: all the work happens in the namespaced function above; its
// return value (0 equal, 1 different, -1 error) becomes the process exit code.
int main(int argc, char* argv[]) {
  return tensorflow::graph_transforms::ParseFlagsAndCompareGraphs(argc, argv);
}

View File

@ -40,10 +40,6 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
required_nodes.insert(output); required_nodes.insert(output);
} }
for (const string& required_node : required_nodes) {
LOG(INFO) << "required_node=" << required_node;
}
const string valid_chars = const string valid_chars =
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
const int64 chars_size = valid_chars.size(); const int64 chars_size = valid_chars.size();

View File

@ -157,22 +157,8 @@ Status ExtractRangeFromParams(const TransformFuncContext& context,
return errors::InvalidArgument("You must pass both ", min_name, " and ", return errors::InvalidArgument("You must pass both ", min_name, " and ",
max_name, " into quantize_nodes"); max_name, " into quantize_nodes");
} }
std::vector<string> min_strings = context.params.at(min_name); TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
std::vector<string> max_strings = context.params.at(max_name); TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
if ((min_strings.size() != 1) || (max_strings.size() != 1)) {
return errors::InvalidArgument("You must pass a single ", min_name,
" and single ", max_name,
" value into "
"quantize_nodes");
}
if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) {
return errors::InvalidArgument("Couldn't decode ", min_name,
" as a number: ", min_strings[0]);
}
if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) {
return errors::InvalidArgument("Couldn't decode ", max_name,
" as a number: ", max_strings[0]);
}
return Status::OK(); return Status::OK();
} }

View File

@ -33,16 +33,8 @@ namespace graph_transforms {
Status RoundWeights(const GraphDef& input_graph_def, Status RoundWeights(const GraphDef& input_graph_def,
const TransformFuncContext& context, const TransformFuncContext& context,
GraphDef* output_graph_def) { GraphDef* output_graph_def) {
string num_steps_string; int64 num_steps;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
int32 num_steps;
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
return errors::InvalidArgument(
"Couldn't interpret the num_steps argument to round_weights as a "
"number:",
num_steps_string);
}
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
input_graph_def, {"Const"}, input_graph_def, {"Const"},
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes, [num_steps](const NodeMatch& match, const std::set<string>& input_nodes,

View File

@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Sets the device field of ops in the graph.
//
// Parameters (from `context`):
//   device:     the device string to assign, e.g. "/device:CPU:0". Defaults
//               to the empty string, which clears the field.
//   if_default: when true, only nodes with no existing device assignment are
//               updated. Defaults to false (assign to every node).
Status SetDevice(const GraphDef& input_graph_def,
                 const TransformFuncContext& context,
                 GraphDef* output_graph_def) {
  string target_device;
  TF_RETURN_IF_ERROR(
      context.GetOneStringParameter("device", "", &target_device));
  bool only_if_unassigned;
  TF_RETURN_IF_ERROR(
      context.GetOneBoolParameter("if_default", false, &only_if_unassigned));

  output_graph_def->Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    // Copy the node across unchanged, then overwrite its device if requested.
    NodeDef* copied_node = output_graph_def->add_node();
    *copied_node = node;
    const bool should_assign = !only_if_unassigned || node.device().empty();
    if (should_assign) {
      copied_node->set_device(target_device);
    }
  }
  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("set_device", SetDevice);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,127 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status SetDevice(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
namespace {
// Builds a small fixture graph for the SetDevice tests: "mul_node1" and
// "add_node2" already carry device assignments, every other node is left
// unassigned, and "add_node4" is not referenced by any other node.
GraphDef CreateDeviceGraph() {
  GraphDef graph_def;

  NodeDef* mul1 = graph_def.add_node();
  mul1->set_name("mul_node1");
  mul1->set_op("Mul");
  mul1->set_device("/device:CPU:0");
  mul1->add_input("add_node2");
  mul1->add_input("add_node3");

  NodeDef* add2 = graph_def.add_node();
  add2->set_name("add_node2");
  add2->set_op("Add");
  add2->add_input("const_node1");
  add2->add_input("const_node2");
  add2->set_device("/device:GPU:1");

  NodeDef* add3 = graph_def.add_node();
  add3->set_name("add_node3");
  add3->set_op("Add");
  add3->add_input("const_node1");
  add3->add_input("const_node3");

  NodeDef* const1 = graph_def.add_node();
  const1->set_name("const_node1");
  const1->set_op("Const");

  NodeDef* const2 = graph_def.add_node();
  const2->set_name("const_node2");
  const2->set_op("Const");

  NodeDef* const3 = graph_def.add_node();
  const3->set_name("const_node3");
  const3->set_op("Const");

  NodeDef* add4 = graph_def.add_node();
  add4->set_name("add_node4");
  add4->set_op("Add");
  add4->add_input("add_node2");
  add4->add_input("add_node3");

  return graph_def;
}
} // namespace
// Without if_default, every node should end up on the requested device,
// regardless of any previous assignment.
TEST(SetDeviceTest, TestSetDevice) {
  GraphDef graph_def = CreateDeviceGraph();
  GraphDef result;
  TransformFuncContext context;
  context.input_names = {};
  context.output_names = {"mul_node1"};
  context.params.insert({"device", {"/device:CPU:0"}});
  TF_ASSERT_OK(SetDevice(graph_def, context, &result));

  std::map<string, const NodeDef*> node_lookup;
  MapNamesToNodes(result, &node_lookup);
  for (const string& name :
       {"mul_node1", "add_node2", "add_node3", "const_node1", "const_node2",
        "const_node3", "add_node4"}) {
    EXPECT_EQ("/device:CPU:0", node_lookup.at(name)->device());
  }
}
// With if_default set, nodes that already have a device keep it and only the
// previously-unassigned nodes pick up the new one.
TEST(SetDeviceTest, TestSetDeviceIfDefault) {
  GraphDef graph_def = CreateDeviceGraph();
  GraphDef result;
  TransformFuncContext context;
  context.input_names = {};
  context.output_names = {"mul_node1"};
  context.params.insert({"device", {"/device:GPU:0"}});
  context.params.insert({"if_default", {"true"}});
  TF_ASSERT_OK(SetDevice(graph_def, context, &result));

  std::map<string, const NodeDef*> node_lookup;
  MapNamesToNodes(result, &node_lookup);
  // These two were assigned in the fixture and must be untouched.
  EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
  EXPECT_EQ("/device:GPU:1", node_lookup.at("add_node2")->device());
  // Everything else had no device and should receive the new one.
  EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node3")->device());
  EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node1")->device());
  EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node2")->device());
  EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node3")->device());
  EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node4")->device());
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -129,12 +129,15 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
string inputs_string = ""; string inputs_string = "";
string outputs_string = ""; string outputs_string = "";
string transforms_string = ""; string transforms_string = "";
bool output_as_text = false;
std::vector<Flag> flag_list = { std::vector<Flag> flag_list = {
Flag("in_graph", &in_graph, "input graph file name"), Flag("in_graph", &in_graph, "input graph file name"),
Flag("out_graph", &out_graph, "output graph file name"), Flag("out_graph", &out_graph, "output graph file name"),
Flag("inputs", &inputs_string, "inputs"), Flag("inputs", &inputs_string, "inputs"),
Flag("outputs", &outputs_string, "outputs"), Flag("outputs", &outputs_string, "outputs"),
Flag("transforms", &transforms_string, "list of transforms"), Flag("transforms", &transforms_string, "list of transforms"),
Flag("output_as_text", &output_as_text,
"whether to write the graph in text protobuf format"),
}; };
string usage = Flags::Usage(argv[0], flag_list); string usage = Flags::Usage(argv[0], flag_list);
usage += "\nTransforms are:\n"; usage += "\nTransforms are:\n";
@ -185,7 +188,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
} }
GraphDef graph_def; GraphDef graph_def;
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def); Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
if (!load_status.ok()) { if (!load_status.ok()) {
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with " LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
<< load_status.error_message(); << load_status.error_message();
@ -202,7 +205,12 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
return -1; return -1;
} }
Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def); Status save_status;
if (output_as_text) {
save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
} else {
save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
}
if (!save_status.ok()) { if (!save_status.ok()) {
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with " LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
<< save_status.error_message(); << save_status.error_message();

View File

@ -583,23 +583,45 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
return Status::OK(); return Status::OK();
} }
int CountParameters(const TransformFuncContext& context, const string& name) { Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
if (context.params.count(name)) { string file_data;
return context.params.at(name).size(); Status load_file_status =
ReadFileToString(Env::Default(), file_name, &file_data);
if (!load_file_status.ok()) {
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
return load_file_status;
}
// Try to load in binary format first, and then try ascii if that fails.
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
if (!load_status.ok()) {
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
load_status = Status::OK();
} else {
errors::AppendToMessage(&load_status,
" (both text and binary parsing failed for file ",
file_name, ")");
}
}
return load_status;
}
int TransformFuncContext::CountParameters(const string& name) const {
if (params.count(name)) {
return params.at(name).size();
} else { } else {
return 0; return 0;
} }
} }
Status GetExactlyOneParameter(const TransformFuncContext& context, Status TransformFuncContext::GetOneStringParameter(const string& name,
const string& name, const string& default_value, const string& default_value,
string* result) { string* result) const {
const int params_count = CountParameters(context, name); const int params_count = CountParameters(name);
if (params_count == 0) { if (params_count == 0) {
*result = default_value; *result = default_value;
return Status::OK(); return Status::OK();
} else if (params_count == 1) { } else if (params_count == 1) {
*result = context.params.at(name).at(0); *result = params.at(name).at(0);
return Status::OK(); return Status::OK();
} else { } else {
return errors::InvalidArgument("Expected a single '", name, return errors::InvalidArgument("Expected a single '", name,
@ -608,5 +630,62 @@ Status GetExactlyOneParameter(const TransformFuncContext& context,
} }
} }
// Looks up `name` in the params map and parses it as a 64-bit integer.
// Falls back to `default_value` when the parameter is absent; returns an
// error if it appears more than once or isn't a valid number.
Status TransformFuncContext::GetOneIntParameter(const string& name,
                                                int64 default_value,
                                                int64* result) const {
  const int params_count = CountParameters(name);
  if (params_count == 0) {
    *result = default_value;
    return Status::OK();
  }
  // Exactly-one enforcement and retrieval are delegated to the string getter;
  // the "" default is never used since we know the parameter is present.
  string string_value;
  TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
  if (!strings::safe_strto64(StringPiece(string_value), result)) {
    return errors::InvalidArgument("Couldn't interpret the ", name,
                                   " argument as a number:", string_value);
  }
  return Status::OK();
}
}
// Looks up `name` in the params map and parses it as a float.
// Falls back to `default_value` when the parameter is absent; returns an
// error if it appears more than once or isn't a valid number.
Status TransformFuncContext::GetOneFloatParameter(const string& name,
                                                  float default_value,
                                                  float* result) const {
  const int params_count = CountParameters(name);
  if (params_count == 0) {
    *result = default_value;
    return Status::OK();
  }
  // Exactly-one enforcement and retrieval are delegated to the string getter;
  // the "" default is never used since we know the parameter is present.
  string string_value;
  TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
  if (!strings::safe_strtof(string_value.c_str(), result)) {
    return errors::InvalidArgument(
        "Couldn't interpret the ", name,
        " argument as a float number:", string_value);
  }
  return Status::OK();
}
}
// Looks up `name` in the params map and interprets it as a boolean.
// Falls back to `default_value` when the parameter is absent; returns an
// error if it appears more than once or isn't one of "true", "1", "false",
// or "0".
Status TransformFuncContext::GetOneBoolParameter(const string& name,
                                                 bool default_value,
                                                 bool* result) const {
  const int params_count = CountParameters(name);
  if (params_count == 0) {
    *result = default_value;
    return Status::OK();
  }
  // Exactly-one enforcement and retrieval are delegated to the string getter;
  // the "" default is never used since we know the parameter is present.
  string string_value;
  TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
  if (string_value == "true" || string_value == "1") {
    *result = true;
  } else if (string_value == "false" || string_value == "0") {
    *result = false;
  } else {
    return errors::InvalidArgument("Couldn't interpret the ", name,
                                   " argument as a boolean:", string_value,
                                   " (expected true, false, 0 or 1)");
  }
  return Status::OK();
}
}
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow

View File

@ -128,6 +128,10 @@ Status IsGraphValid(const GraphDef& graph_def);
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
DataTypeVector* outputs); DataTypeVector* outputs);
// First tries to load the file as a binary protobuf, if that fails tries to
// parse it as a text protobuf, and returns an error if both fail.
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
// This is used to spot particular subgraphs in a larger model. To use it, // This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like: // create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
@ -213,16 +217,33 @@ struct TransformFuncContext {
std::vector<string> input_names; std::vector<string> input_names;
std::vector<string> output_names; std::vector<string> output_names;
TransformFuncParameters params; TransformFuncParameters params;
// Returns how many occurrences of the given parameter are present.
int CountParameters(const string& name) const;
// Gets a single instance of a parameter, using a default if it's not present.
Status GetOneStringParameter(const string& name, const string& default_value,
string* result) const;
// Gets a single occurrence of a parameter as an integer, falling back to a
// default if it isn't present and returning an error if it isn't convertible
// to a number.
Status GetOneIntParameter(const string& name, int64 default_value,
int64* result) const;
// Gets a single occurrence of a parameter as a floating point number, falling
// back to a default if it isn't present and returning an error if it isn't
// convertible to a number.
Status GetOneFloatParameter(const string& name, float default_value,
float* result) const;
// Gets a single occurrence of a parameter as a boolean, falling back to a
// default if it isn't present and returning an error if it's not one of
// "true", "1", "false", or "0".
Status GetOneBoolParameter(const string& name, bool default_value,
bool* result) const;
}; };
// Returns how many occurrences of the given parameter are present.
int CountParameters(const TransformFuncContext& context, const string& name);
// Gets a simple occurrence of a parameter, using a default if it isn't present.
Status GetExactlyOneParameter(const TransformFuncContext& context,
const string& name, const string& default_value,
string* result);
// This is the function API for all graph transformations, taking an input // This is the function API for all graph transformations, taking an input
// GraphDef and other arguments, and returning a transformed GraphDef. // GraphDef and other arguments, and returning a transformed GraphDef.
typedef std::function<Status(const GraphDef&, typedef std::function<Status(const GraphDef&,

View File

@ -19,7 +19,9 @@ limitations under the License.
#include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/equal_graph_def.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
@ -924,22 +926,131 @@ class TransformUtilsTest : public ::testing::Test {
TransformFuncContext context; TransformFuncContext context;
context.params.insert({"foo", {"a", "b"}}); context.params.insert({"foo", {"a", "b"}});
context.params.insert({"bar", {"c"}}); context.params.insert({"bar", {"c"}});
EXPECT_EQ(2, CountParameters(context, "foo")); EXPECT_EQ(2, context.CountParameters("foo"));
EXPECT_EQ(1, CountParameters(context, "bar")); EXPECT_EQ(1, context.CountParameters("bar"));
EXPECT_EQ(0, CountParameters(context, "not_present")); EXPECT_EQ(0, context.CountParameters("not_present"));
} }
void TestGetExactlyOneParameter() { void TestGetOneStringParameter() {
TransformFuncContext context; TransformFuncContext context;
context.params.insert({"foo", {"a", "b"}}); context.params.insert({"foo", {"a", "b"}});
context.params.insert({"bar", {"c"}}); context.params.insert({"bar", {"c"}});
string value; string value;
TF_EXPECT_OK(GetExactlyOneParameter(context, "bar", "d", &value)); TF_EXPECT_OK(context.GetOneStringParameter("bar", "d", &value));
EXPECT_EQ("c", value); EXPECT_EQ("c", value);
EXPECT_FALSE(GetExactlyOneParameter(context, "foo", "d", &value).ok()); EXPECT_FALSE(context.GetOneStringParameter("foo", "d", &value).ok());
TF_EXPECT_OK(GetExactlyOneParameter(context, "not_present", "d", &value)); TF_EXPECT_OK(context.GetOneStringParameter("not_present", "d", &value));
EXPECT_EQ("d", value); EXPECT_EQ("d", value);
} }
void TestGetOneIntParameter() {
TransformFuncContext context;
context.params.insert({"foo", {"10", "20"}});
context.params.insert({"bar", {"-23"}});
context.params.insert({"not_a_number", {"not_numerical"}});
context.params.insert({"float", {"-23.232323"}});
int64 value;
TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
EXPECT_EQ(-23, value);
EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
EXPECT_EQ(10, value);
EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
}
void TestGetOneFloatParameter() {
  // Verifies float parameter lookup: parsing a single value, rejecting
  // duplicate occurrences, defaulting when the parameter is absent, and
  // rejecting non-numeric text.
  TransformFuncContext context;
  context.params.insert({"foo", {"10.0", "20.0"}});
  context.params.insert({"bar", {"-23.2323"}});
  context.params.insert({"not_a_number", {"not_numerical"}});
  float value;
  // A single well-formed occurrence parses successfully. Use float
  // literals (0.0f) for the default to match the parameter type instead
  // of relying on implicit int-to-float conversion.
  TF_EXPECT_OK(context.GetOneFloatParameter("bar", 0.0f, &value));
  EXPECT_NEAR(-23.2323f, value, 1e-5f);
  // More than one occurrence of the same name is an error.
  EXPECT_FALSE(context.GetOneFloatParameter("foo", 0.0f, &value).ok());
  // A missing parameter falls back to the supplied default.
  TF_EXPECT_OK(context.GetOneFloatParameter("not_present", 10.5f, &value));
  EXPECT_NEAR(10.5f, value, 1e-5f);
  // Non-numeric values are rejected.
  EXPECT_FALSE(context.GetOneFloatParameter("not_a_number", 0.0f, &value).ok());
}
void TestGetOneBoolParameter() {
  // Verifies boolean parameter lookup accepts "true"/"false"/"1"/"0",
  // rejects duplicate occurrences and non-boolean text, and falls back
  // to the supplied default when the parameter is absent.
  TransformFuncContext context;
  context.params.insert({"foo", {"true", "false"}});
  context.params.insert({"true", {"true"}});
  context.params.insert({"false", {"false"}});
  context.params.insert({"one", {"1"}});
  context.params.insert({"zero", {"0"}});
  context.params.insert({"not_a_bool", {"not_boolean"}});
  bool value;
  // More than one occurrence of the same name is an error. Pass a bool
  // literal (false), not the int 0, to match the parameter type.
  EXPECT_FALSE(context.GetOneBoolParameter("foo", false, &value).ok());
  // "true" and "1" both parse to true; reset `value` before each call so
  // we know the call actually wrote the output.
  value = false;
  TF_EXPECT_OK(context.GetOneBoolParameter("true", false, &value));
  EXPECT_TRUE(value);
  value = true;
  TF_EXPECT_OK(context.GetOneBoolParameter("false", true, &value));
  EXPECT_FALSE(value);
  value = false;
  TF_EXPECT_OK(context.GetOneBoolParameter("one", false, &value));
  EXPECT_TRUE(value);
  value = true;
  TF_EXPECT_OK(context.GetOneBoolParameter("zero", true, &value));
  EXPECT_FALSE(value);
  // Anything other than true/false/1/0 is rejected.
  EXPECT_FALSE(context.GetOneBoolParameter("not_a_bool", false, &value).ok());
  // A missing parameter falls back to the supplied default.
  value = false;
  TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
  EXPECT_TRUE(value);
}
void TestLoadTextOrBinaryGraphFile() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
auto root = tensorflow::Scope::NewRootScope();
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
const string text_file =
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
const string binary_file =
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
TF_ASSERT_OK(
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
GraphDef text_graph_def;
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
string text_diff;
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
<< text_diff;
GraphDef binary_graph_def;
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
string binary_diff;
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
<< binary_diff;
GraphDef no_graph_def;
EXPECT_FALSE(
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
.ok());
GraphDef bogus_graph_def;
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
}
}; };
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
@ -1012,8 +1123,22 @@ TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); } TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }
TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) { TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
TestGetExactlyOneParameter(); TestGetOneStringParameter();
}
// Registers the integer-parameter helper above as its own test case.
TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
// Registers the float-parameter helper above as its own test case.
TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
  TestGetOneFloatParameter();
}
// Registers the boolean-parameter helper above as its own test case.
TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
  TestGetOneBoolParameter();
}
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
TestLoadTextOrBinaryGraphFile();
} }
} // namespace graph_transforms } // namespace graph_transforms