Add new transforms and compare_graph tool
Change: 144024968
This commit is contained in:
parent
bfa0fa079d
commit
80679a6a74
@ -43,6 +43,7 @@ tf_cc_test(
|
||||
":transform_utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -55,6 +56,7 @@ tf_cc_test(
|
||||
cc_library(
|
||||
name = "transforms_lib",
|
||||
srcs = [
|
||||
"add_default_attributes.cc",
|
||||
"fold_batch_norms.cc",
|
||||
"fold_constants_lib.cc",
|
||||
"fold_old_batch_norms.cc",
|
||||
@ -68,6 +70,7 @@ cc_library(
|
||||
"rename_attribute.cc",
|
||||
"rename_op.cc",
|
||||
"round_weights.cc",
|
||||
"set_device.cc",
|
||||
"sort_by_execution_order.cc",
|
||||
"strip_unused_nodes.cc",
|
||||
],
|
||||
@ -80,6 +83,7 @@ cc_library(
|
||||
":transform_utils",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -93,6 +97,7 @@ tf_cc_test(
|
||||
name = "transforms_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"add_default_attributes_test.cc",
|
||||
"fold_batch_norms_test.cc",
|
||||
"fold_constants_test.cc",
|
||||
"fold_old_batch_norms_test.cc",
|
||||
@ -106,6 +111,7 @@ tf_cc_test(
|
||||
"rename_attribute_test.cc",
|
||||
"rename_op_test.cc",
|
||||
"round_weights_test.cc",
|
||||
"set_device_test.cc",
|
||||
"sort_by_execution_order_test.cc",
|
||||
"strip_unused_nodes_test.cc",
|
||||
],
|
||||
@ -209,3 +215,18 @@ cc_binary(
|
||||
":summarize_graph_main_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "compare_graphs",
|
||||
srcs = ["compare_graphs.cc"],
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":transform_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
@ -12,6 +12,7 @@
|
||||
* [Shrinking File Size](#shrinking-file-size)
|
||||
* [Eight-bit Calculations](#eight-bit-calculations)
|
||||
* [Transform Reference](#transform-reference)
|
||||
* [add_default_attributes](#add_default_attributes)
|
||||
* [fold_batch_norms](#fold_batch_norms)
|
||||
* [fold_constants](#fold_constants)
|
||||
* [fold_old_batch_norms](#fold_old_batch_norms)
|
||||
@ -26,6 +27,7 @@
|
||||
* [rename_attribute](#rename_attribute)
|
||||
* [rename_op](#rename_op)
|
||||
* [round_weights](#round_weights)
|
||||
* [set_device](#set_device)
|
||||
* [sort_by_execution_order](#sort_by_execution_order)
|
||||
* [strip_unused_nodes](#strip_unused_nodes)
|
||||
* [Writing Your Own Transforms](#writing-your-own-transforms)
|
||||
@ -318,6 +320,18 @@ logged and the transform skipped. This is especially useful for optional
|
||||
transforms where version errors or other unimportant problems may trigger an
|
||||
error.
|
||||
|
||||
### add_default_attributes
|
||||
|
||||
Args: None
|
||||
|
||||
When attributes are added to ops in new versions of TensorFlow, they often have
|
||||
defaults to ensure backwards compatible behavior with their original versions.
|
||||
These defaults usually get added when the graph is loaded by the runtime, but if
|
||||
your model is going to be processed outside of the main TensorFlow framework it
|
||||
can be useful to run this update process as a transform. This process finds any
|
||||
op attributes that are defined in the current TensorFlow list of ops but not
|
||||
within the saved model, and sets them to the defined default for that attribute.
|
||||
|
||||
### fold_batch_norms
|
||||
|
||||
Args: None
|
||||
@ -516,6 +530,22 @@ between the largest and smallest values present. This is useful when you'll be
|
||||
deploying on mobile, and you want a model that will compress effectively. See
|
||||
[shrinking file size](#shrinking-file-size) for more details.
|
||||
|
||||
### set_device
|
||||
|
||||
Args:
|
||||
|
||||
* device: What device to assign to ops.
|
||||
* if_default: If this is true, only assign to ops with empty existing devices.
|
||||
|
||||
Updates nodes to use the specified device. A device is a way to tell the code
|
||||
that executes the graph which piece of hardware it should run particular nodes
|
||||
on. The right assignment to use may change between training and deployment, so
|
||||
this transform (and [remove_device](#remove_device)) provide a way of updating
|
||||
the placement. If the `is_default` parameter is set, then only ops that don't
|
||||
have a device assigned already will be updated. This is mostly useful for
|
||||
preprocessing of graphs for other stages that expect all ops to have an explicit
|
||||
device assigned.
|
||||
|
||||
### sort_by_execution_order
|
||||
|
||||
Args: None\
|
||||
@ -844,23 +874,13 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
|
||||
parameter:
|
||||
|
||||
```C++
|
||||
string num_steps_string;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
|
||||
int32 num_steps;
|
||||
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
|
||||
return errors::InvalidArgument(
|
||||
"Couldn't interpret the num_steps argument to round_weights as a "
|
||||
"number:",
|
||||
num_steps_string);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
|
||||
```
|
||||
|
||||
Something to notice here is that you have to convert the string to an integer,
|
||||
and if the conversion fails you need to raise a meaningful error through the
|
||||
status result of the transform. Also, we're using a helper function which raises
|
||||
an error if the parameter is present multiple times, and uses a default if the
|
||||
user hasn't specified it.
|
||||
If the conversion fails or the parameter occurs more than once the helper
|
||||
function will raise a meaningful error through the status result of the
|
||||
transform. If the parameter isn't specified at all then the default will be
|
||||
used.
|
||||
|
||||
### Function Libraries
|
||||
|
||||
|
40
tensorflow/tools/graph_transforms/add_default_attributes.cc
Normal file
40
tensorflow/tools/graph_transforms/add_default_attributes.cc
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Sets any parameters not specified in a node to their defaults.
|
||||
Status AddDefaultAttributes(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
// Find all of the ops that are currently defined.
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(),
|
||||
input_graph_def.library()));
|
||||
// Works in-place, so copy over the original graph.
|
||||
*output_graph_def = input_graph_def;
|
||||
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(output_graph_def, *flib_def, 0));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("add_default_attributes", AddDefaultAttributes);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -0,0 +1,74 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status AddDefaultAttributes(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
class AddDefaultAttributesTest : public ::testing::Test {
|
||||
protected:
|
||||
void TestAddDefaultAttributes() {
|
||||
GraphDef graph_def;
|
||||
|
||||
NodeDef* lrn_node1 = graph_def.add_node();
|
||||
lrn_node1->set_name("lrn_node1");
|
||||
lrn_node1->set_op("LRN");
|
||||
|
||||
NodeDef* lrn_node2 = graph_def.add_node();
|
||||
lrn_node2->set_name("lrn_node2");
|
||||
lrn_node2->set_op("LRN");
|
||||
SetNodeAttr("depth_radius", 7, lrn_node2);
|
||||
SetNodeAttr("bias", 2.0f, lrn_node2);
|
||||
SetNodeAttr("alpha", 2.0f, lrn_node2);
|
||||
SetNodeAttr("beta", 1.0f, lrn_node2);
|
||||
|
||||
GraphDef result;
|
||||
TF_ASSERT_OK(AddDefaultAttributes(graph_def, {}, &result));
|
||||
|
||||
std::map<string, const NodeDef*> nodes;
|
||||
MapNamesToNodes(result, &nodes);
|
||||
EXPECT_EQ(5, nodes.at("lrn_node1")->attr().at("depth_radius").i());
|
||||
EXPECT_NEAR(1.0f, nodes.at("lrn_node1")->attr().at("bias").f(), 1e-5f);
|
||||
EXPECT_NEAR(1.0f, nodes.at("lrn_node1")->attr().at("alpha").f(), 1e-5f);
|
||||
EXPECT_NEAR(0.5f, nodes.at("lrn_node1")->attr().at("beta").f(), 1e-5f);
|
||||
EXPECT_EQ(7, nodes.at("lrn_node2")->attr().at("depth_radius").i());
|
||||
EXPECT_NEAR(2.0f, nodes.at("lrn_node2")->attr().at("bias").f(), 1e-5f);
|
||||
EXPECT_NEAR(2.0f, nodes.at("lrn_node2")->attr().at("alpha").f(), 1e-5f);
|
||||
EXPECT_NEAR(1.0f, nodes.at("lrn_node2")->attr().at("beta").f(), 1e-5f);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(AddDefaultAttributesTest, TestAddDefaultAttributes) {
|
||||
TestAddDefaultAttributes();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
79
tensorflow/tools/graph_transforms/compare_graphs.cc
Normal file
79
tensorflow/tools/graph_transforms/compare_graphs.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Compares two TensorFlow graphs to see if their meaning is the same. This is a
|
||||
// semantic comparison that's intended to show whether the graphs should produce
|
||||
// the same results, and so ignores details like version numbers or node
|
||||
// ordering that don't affect the output. To use it, run something like this:
|
||||
//
|
||||
// bazel build tensorflow/tools/graph_transforms:compare_graphs
|
||||
// bazel-bin/tensorflow/tools/graph_transforms/compare_graphs a.pb b.pb
|
||||
//
|
||||
// The return value is 0 if the graphs are equal, 1 if they're different, and -1
|
||||
// if there was a problem.
|
||||
|
||||
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
namespace {
|
||||
|
||||
int ParseFlagsAndCompareGraphs(int argc, char* argv[]) {
|
||||
// We need to call this to set up global state for TensorFlow.
|
||||
port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
if (argc != 3) {
|
||||
LOG(ERROR) << "compare_graphs expects two file names as arguments";
|
||||
return -1;
|
||||
}
|
||||
|
||||
GraphDef a;
|
||||
Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a);
|
||||
if (!a_load_status.ok()) {
|
||||
LOG(ERROR) << "Loading graph '" << argv[1] << "' failed with "
|
||||
<< a_load_status.error_message();
|
||||
return -1;
|
||||
}
|
||||
|
||||
GraphDef b;
|
||||
Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b);
|
||||
if (!b_load_status.ok()) {
|
||||
LOG(ERROR) << "Loading graph '" << argv[2] << "' failed with "
|
||||
<< b_load_status.error_message();
|
||||
return -1;
|
||||
}
|
||||
|
||||
string diff;
|
||||
if (EqualGraphDef(a, b, &diff)) {
|
||||
std::cout << "Graphs are equal." << std::endl;
|
||||
return 0;
|
||||
} else {
|
||||
std::cout << diff << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return tensorflow::graph_transforms::ParseFlagsAndCompareGraphs(argc, argv);
|
||||
}
|
@ -40,10 +40,6 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||
required_nodes.insert(output);
|
||||
}
|
||||
|
||||
for (const string& required_node : required_nodes) {
|
||||
LOG(INFO) << "required_node=" << required_node;
|
||||
}
|
||||
|
||||
const string valid_chars =
|
||||
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||
const int64 chars_size = valid_chars.size();
|
||||
|
@ -157,22 +157,8 @@ Status ExtractRangeFromParams(const TransformFuncContext& context,
|
||||
return errors::InvalidArgument("You must pass both ", min_name, " and ",
|
||||
max_name, " into quantize_nodes");
|
||||
}
|
||||
std::vector<string> min_strings = context.params.at(min_name);
|
||||
std::vector<string> max_strings = context.params.at(max_name);
|
||||
if ((min_strings.size() != 1) || (max_strings.size() != 1)) {
|
||||
return errors::InvalidArgument("You must pass a single ", min_name,
|
||||
" and single ", max_name,
|
||||
" value into "
|
||||
"quantize_nodes");
|
||||
}
|
||||
if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) {
|
||||
return errors::InvalidArgument("Couldn't decode ", min_name,
|
||||
" as a number: ", min_strings[0]);
|
||||
}
|
||||
if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) {
|
||||
return errors::InvalidArgument("Couldn't decode ", max_name,
|
||||
" as a number: ", max_strings[0]);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
|
||||
TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -33,16 +33,8 @@ namespace graph_transforms {
|
||||
Status RoundWeights(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
string num_steps_string;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
|
||||
int32 num_steps;
|
||||
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
|
||||
return errors::InvalidArgument(
|
||||
"Couldn't interpret the num_steps argument to round_weights as a "
|
||||
"number:",
|
||||
num_steps_string);
|
||||
}
|
||||
int64 num_steps;
|
||||
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
|
||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||
input_graph_def, {"Const"},
|
||||
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||
|
46
tensorflow/tools/graph_transforms/set_device.cc
Normal file
46
tensorflow/tools/graph_transforms/set_device.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Sets the device field of ops in the graph.
|
||||
Status SetDevice(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
string new_device;
|
||||
TF_RETURN_IF_ERROR(context.GetOneStringParameter("device", "", &new_device));
|
||||
bool if_default;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneBoolParameter("if_default", false, &if_default));
|
||||
|
||||
output_graph_def->Clear();
|
||||
for (const NodeDef& node : input_graph_def.node()) {
|
||||
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||
new_node->CopyFrom(node);
|
||||
if (!if_default || (node.device() == "")) {
|
||||
new_node->set_device(new_device);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("set_device", SetDevice);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
127
tensorflow/tools/graph_transforms/set_device_test.cc
Normal file
127
tensorflow/tools/graph_transforms/set_device_test.cc
Normal file
@ -0,0 +1,127 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status SetDevice(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
namespace {
|
||||
GraphDef CreateDeviceGraph() {
|
||||
GraphDef graph_def;
|
||||
|
||||
NodeDef* mul_node1 = graph_def.add_node();
|
||||
mul_node1->set_name("mul_node1");
|
||||
mul_node1->set_op("Mul");
|
||||
mul_node1->set_device("/device:CPU:0");
|
||||
mul_node1->add_input("add_node2");
|
||||
mul_node1->add_input("add_node3");
|
||||
|
||||
NodeDef* add_node2 = graph_def.add_node();
|
||||
add_node2->set_name("add_node2");
|
||||
add_node2->set_op("Add");
|
||||
add_node2->add_input("const_node1");
|
||||
add_node2->add_input("const_node2");
|
||||
add_node2->set_device("/device:GPU:1");
|
||||
|
||||
NodeDef* add_node3 = graph_def.add_node();
|
||||
add_node3->set_name("add_node3");
|
||||
add_node3->set_op("Add");
|
||||
add_node3->add_input("const_node1");
|
||||
add_node3->add_input("const_node3");
|
||||
|
||||
NodeDef* const_node1 = graph_def.add_node();
|
||||
const_node1->set_name("const_node1");
|
||||
const_node1->set_op("Const");
|
||||
|
||||
NodeDef* const_node2 = graph_def.add_node();
|
||||
const_node2->set_name("const_node2");
|
||||
const_node2->set_op("Const");
|
||||
|
||||
NodeDef* const_node3 = graph_def.add_node();
|
||||
const_node3->set_name("const_node3");
|
||||
const_node3->set_op("Const");
|
||||
|
||||
NodeDef* add_node4 = graph_def.add_node();
|
||||
add_node4->set_name("add_node4");
|
||||
add_node4->set_op("Add");
|
||||
add_node4->add_input("add_node2");
|
||||
add_node4->add_input("add_node3");
|
||||
|
||||
return graph_def;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(SetDeviceTest, TestSetDevice) {
|
||||
GraphDef graph_def = CreateDeviceGraph();
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {};
|
||||
context.output_names = {"mul_node1"};
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{"device", {string("/device:CPU:0")}}));
|
||||
TF_ASSERT_OK(SetDevice(graph_def, context, &result));
|
||||
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node2")->device());
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node3")->device());
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node1")->device());
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node2")->device());
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node3")->device());
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node4")->device());
|
||||
}
|
||||
|
||||
TEST(SetDeviceTest, TestSetDeviceIfDefault) {
|
||||
GraphDef graph_def = CreateDeviceGraph();
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {};
|
||||
context.output_names = {"mul_node1"};
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{"device", {string("/device:GPU:0")}}));
|
||||
context.params.insert(
|
||||
std::pair<string, std::vector<string>>({"if_default", {string("true")}}));
|
||||
TF_ASSERT_OK(SetDevice(graph_def, context, &result));
|
||||
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
|
||||
EXPECT_EQ("/device:GPU:1", node_lookup.at("add_node2")->device());
|
||||
EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node3")->device());
|
||||
EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node1")->device());
|
||||
EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node2")->device());
|
||||
EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node3")->device());
|
||||
EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node4")->device());
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -129,12 +129,15 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
||||
string inputs_string = "";
|
||||
string outputs_string = "";
|
||||
string transforms_string = "";
|
||||
bool output_as_text = false;
|
||||
std::vector<Flag> flag_list = {
|
||||
Flag("in_graph", &in_graph, "input graph file name"),
|
||||
Flag("out_graph", &out_graph, "output graph file name"),
|
||||
Flag("inputs", &inputs_string, "inputs"),
|
||||
Flag("outputs", &outputs_string, "outputs"),
|
||||
Flag("transforms", &transforms_string, "list of transforms"),
|
||||
Flag("output_as_text", &output_as_text,
|
||||
"whether to write the graph in text protobuf format"),
|
||||
};
|
||||
string usage = Flags::Usage(argv[0], flag_list);
|
||||
usage += "\nTransforms are:\n";
|
||||
@ -185,7 +188,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
||||
}
|
||||
|
||||
GraphDef graph_def;
|
||||
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
|
||||
Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
|
||||
if (!load_status.ok()) {
|
||||
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
|
||||
<< load_status.error_message();
|
||||
@ -202,7 +205,12 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
|
||||
Status save_status;
|
||||
if (output_as_text) {
|
||||
save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
|
||||
} else {
|
||||
save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
|
||||
}
|
||||
if (!save_status.ok()) {
|
||||
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
|
||||
<< save_status.error_message();
|
||||
|
@ -583,23 +583,45 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int CountParameters(const TransformFuncContext& context, const string& name) {
|
||||
if (context.params.count(name)) {
|
||||
return context.params.at(name).size();
|
||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
|
||||
string file_data;
|
||||
Status load_file_status =
|
||||
ReadFileToString(Env::Default(), file_name, &file_data);
|
||||
if (!load_file_status.ok()) {
|
||||
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
|
||||
return load_file_status;
|
||||
}
|
||||
// Try to load in binary format first, and then try ascii if that fails.
|
||||
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
|
||||
if (!load_status.ok()) {
|
||||
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
|
||||
load_status = Status::OK();
|
||||
} else {
|
||||
errors::AppendToMessage(&load_status,
|
||||
" (both text and binary parsing failed for file ",
|
||||
file_name, ")");
|
||||
}
|
||||
}
|
||||
return load_status;
|
||||
}
|
||||
|
||||
int TransformFuncContext::CountParameters(const string& name) const {
|
||||
if (params.count(name)) {
|
||||
return params.at(name).size();
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
Status GetExactlyOneParameter(const TransformFuncContext& context,
|
||||
const string& name, const string& default_value,
|
||||
string* result) {
|
||||
const int params_count = CountParameters(context, name);
|
||||
Status TransformFuncContext::GetOneStringParameter(const string& name,
|
||||
const string& default_value,
|
||||
string* result) const {
|
||||
const int params_count = CountParameters(name);
|
||||
if (params_count == 0) {
|
||||
*result = default_value;
|
||||
return Status::OK();
|
||||
} else if (params_count == 1) {
|
||||
*result = context.params.at(name).at(0);
|
||||
*result = params.at(name).at(0);
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::InvalidArgument("Expected a single '", name,
|
||||
@ -608,5 +630,62 @@ Status GetExactlyOneParameter(const TransformFuncContext& context,
|
||||
}
|
||||
}
|
||||
|
||||
Status TransformFuncContext::GetOneIntParameter(const string& name,
|
||||
int64 default_value,
|
||||
int64* result) const {
|
||||
const int params_count = CountParameters(name);
|
||||
if (params_count == 0) {
|
||||
*result = default_value;
|
||||
return Status::OK();
|
||||
}
|
||||
string string_value;
|
||||
TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
|
||||
if (!strings::safe_strto64(StringPiece(string_value), result)) {
|
||||
return errors::InvalidArgument("Couldn't interpret the ", name,
|
||||
" argument as a number:", string_value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TransformFuncContext::GetOneFloatParameter(const string& name,
|
||||
float default_value,
|
||||
float* result) const {
|
||||
const int params_count = CountParameters(name);
|
||||
if (params_count == 0) {
|
||||
*result = default_value;
|
||||
return Status::OK();
|
||||
}
|
||||
string string_value;
|
||||
TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
|
||||
if (!strings::safe_strtof(string_value.c_str(), result)) {
|
||||
return errors::InvalidArgument(
|
||||
"Couldn't interpret the ", name,
|
||||
" argument as a float number:", string_value);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TransformFuncContext::GetOneBoolParameter(const string& name,
|
||||
bool default_value,
|
||||
bool* result) const {
|
||||
const int params_count = CountParameters(name);
|
||||
if (params_count == 0) {
|
||||
*result = default_value;
|
||||
return Status::OK();
|
||||
}
|
||||
string string_value;
|
||||
TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
|
||||
if (string_value == "true" || string_value == "1") {
|
||||
*result = true;
|
||||
} else if (string_value == "false" || string_value == "0") {
|
||||
*result = false;
|
||||
} else {
|
||||
return errors::InvalidArgument("Couldn't interpret the ", name,
|
||||
" argument as a boolean:", string_value,
|
||||
" (expected true, false, 0 or 1)");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
||||
|
@ -128,6 +128,10 @@ Status IsGraphValid(const GraphDef& graph_def);
|
||||
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||
DataTypeVector* outputs);
|
||||
|
||||
// First tries to load the file as a text protobuf, if that fails tries to parse
|
||||
// it as a binary protobuf, and returns an error if both fail.
|
||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
|
||||
|
||||
// This is used to spot particular subgraphs in a larger model. To use it,
|
||||
// create a pattern like:
|
||||
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
||||
@ -213,16 +217,33 @@ struct TransformFuncContext {
|
||||
std::vector<string> input_names;
|
||||
std::vector<string> output_names;
|
||||
TransformFuncParameters params;
|
||||
|
||||
// Returns how many occurrences of the given parameter are present.
|
||||
int CountParameters(const string& name) const;
|
||||
|
||||
// Gets a single instance of a parameter, using a default if it's not present.
|
||||
Status GetOneStringParameter(const string& name, const string& default_value,
|
||||
string* result) const;
|
||||
|
||||
// Gets a single occurrence of a parameter as an integer, falling back to a
|
||||
// default if it isn't present and returning an error if it isn't convertible
|
||||
// to a number.
|
||||
Status GetOneIntParameter(const string& name, int64 default_value,
|
||||
int64* result) const;
|
||||
|
||||
// Gets a single occurrence of a parameter as a floating point number, falling
|
||||
// back to a default if it isn't present and returning an error if it isn't
|
||||
// convertible to a number.
|
||||
Status GetOneFloatParameter(const string& name, float default_value,
|
||||
float* result) const;
|
||||
|
||||
// Gets a single occurrence of a parameter as a boolean, falling back to a
|
||||
// default if it isn't present and returning an error if it's not one of
|
||||
// "true", "1", "false", or "0".
|
||||
Status GetOneBoolParameter(const string& name, bool default_value,
|
||||
bool* result) const;
|
||||
};
|
||||
|
||||
// Returns how many occurrences of the given parameter are present.
|
||||
int CountParameters(const TransformFuncContext& context, const string& name);
|
||||
|
||||
// Gets a simple occurrence of a parameter, using a default if it isn't present.
|
||||
Status GetExactlyOneParameter(const TransformFuncContext& context,
|
||||
const string& name, const string& default_value,
|
||||
string* result);
|
||||
|
||||
// This is the function API for all graph transformations, taking an input
|
||||
// GraphDef and other arguments, and returning a transformed GraphDef.
|
||||
typedef std::function<Status(const GraphDef&,
|
||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
@ -924,22 +926,131 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"a", "b"}});
|
||||
context.params.insert({"bar", {"c"}});
|
||||
EXPECT_EQ(2, CountParameters(context, "foo"));
|
||||
EXPECT_EQ(1, CountParameters(context, "bar"));
|
||||
EXPECT_EQ(0, CountParameters(context, "not_present"));
|
||||
EXPECT_EQ(2, context.CountParameters("foo"));
|
||||
EXPECT_EQ(1, context.CountParameters("bar"));
|
||||
EXPECT_EQ(0, context.CountParameters("not_present"));
|
||||
}
|
||||
|
||||
void TestGetExactlyOneParameter() {
|
||||
void TestGetOneStringParameter() {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"a", "b"}});
|
||||
context.params.insert({"bar", {"c"}});
|
||||
string value;
|
||||
TF_EXPECT_OK(GetExactlyOneParameter(context, "bar", "d", &value));
|
||||
TF_EXPECT_OK(context.GetOneStringParameter("bar", "d", &value));
|
||||
EXPECT_EQ("c", value);
|
||||
EXPECT_FALSE(GetExactlyOneParameter(context, "foo", "d", &value).ok());
|
||||
TF_EXPECT_OK(GetExactlyOneParameter(context, "not_present", "d", &value));
|
||||
EXPECT_FALSE(context.GetOneStringParameter("foo", "d", &value).ok());
|
||||
TF_EXPECT_OK(context.GetOneStringParameter("not_present", "d", &value));
|
||||
EXPECT_EQ("d", value);
|
||||
}
|
||||
|
||||
void TestGetOneIntParameter() {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"10", "20"}});
|
||||
context.params.insert({"bar", {"-23"}});
|
||||
context.params.insert({"not_a_number", {"not_numerical"}});
|
||||
context.params.insert({"float", {"-23.232323"}});
|
||||
int64 value;
|
||||
TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
|
||||
EXPECT_EQ(-23, value);
|
||||
EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
|
||||
TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
|
||||
EXPECT_EQ(10, value);
|
||||
EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
|
||||
EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
|
||||
}
|
||||
|
||||
void TestGetOneFloatParameter() {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"10.0", "20.0"}});
|
||||
context.params.insert({"bar", {"-23.2323"}});
|
||||
context.params.insert({"not_a_number", {"not_numerical"}});
|
||||
float value;
|
||||
TF_EXPECT_OK(context.GetOneFloatParameter("bar", 0, &value));
|
||||
EXPECT_NEAR(-23.2323f, value, 1e-5f);
|
||||
EXPECT_FALSE(context.GetOneFloatParameter("foo", 0, &value).ok());
|
||||
TF_EXPECT_OK(context.GetOneFloatParameter("not_present", 10.5f, &value));
|
||||
EXPECT_NEAR(10.5f, value, 1e-5f);
|
||||
EXPECT_FALSE(context.GetOneFloatParameter("not_a_number", 0, &value).ok());
|
||||
}
|
||||
|
||||
void TestGetOneBoolParameter() {
|
||||
TransformFuncContext context;
|
||||
context.params.insert({"foo", {"true", "false"}});
|
||||
context.params.insert({"true", {"true"}});
|
||||
context.params.insert({"false", {"false"}});
|
||||
context.params.insert({"one", {"1"}});
|
||||
context.params.insert({"zero", {"0"}});
|
||||
context.params.insert({"not_a_bool", {"not_boolean"}});
|
||||
|
||||
bool value;
|
||||
EXPECT_FALSE(context.GetOneBoolParameter("foo", 0, &value).ok());
|
||||
|
||||
value = false;
|
||||
TF_EXPECT_OK(context.GetOneBoolParameter("true", false, &value));
|
||||
EXPECT_TRUE(value);
|
||||
|
||||
value = true;
|
||||
TF_EXPECT_OK(context.GetOneBoolParameter("false", true, &value));
|
||||
EXPECT_FALSE(value);
|
||||
|
||||
value = false;
|
||||
TF_EXPECT_OK(context.GetOneBoolParameter("one", false, &value));
|
||||
EXPECT_TRUE(value);
|
||||
|
||||
value = true;
|
||||
TF_EXPECT_OK(context.GetOneBoolParameter("zero", true, &value));
|
||||
EXPECT_FALSE(value);
|
||||
|
||||
EXPECT_FALSE(context.GetOneBoolParameter("not_a_bool", false, &value).ok());
|
||||
|
||||
value = false;
|
||||
TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
|
||||
EXPECT_TRUE(value);
|
||||
}
|
||||
|
||||
void TestLoadTextOrBinaryGraphFile() {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
const int width = 10;
|
||||
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||
test::FillIota<float>(&a_data, 1.0f);
|
||||
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
const string text_file =
|
||||
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
|
||||
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
|
||||
|
||||
const string binary_file =
|
||||
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
|
||||
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
|
||||
|
||||
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
|
||||
TF_ASSERT_OK(
|
||||
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
|
||||
|
||||
GraphDef text_graph_def;
|
||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
|
||||
string text_diff;
|
||||
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
|
||||
<< text_diff;
|
||||
|
||||
GraphDef binary_graph_def;
|
||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
|
||||
string binary_diff;
|
||||
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
|
||||
<< binary_diff;
|
||||
|
||||
GraphDef no_graph_def;
|
||||
EXPECT_FALSE(
|
||||
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
|
||||
.ok());
|
||||
|
||||
GraphDef bogus_graph_def;
|
||||
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
||||
@ -1012,8 +1123,22 @@ TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) {
|
||||
TestGetExactlyOneParameter();
|
||||
TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
|
||||
TestGetOneStringParameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
|
||||
TestGetOneFloatParameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
|
||||
TestGetOneBoolParameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
|
||||
TestLoadTextOrBinaryGraphFile();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
|
Loading…
Reference in New Issue
Block a user