diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 0e8da270170..3bd2cc9ac92 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -43,6 +43,7 @@ tf_cc_test( ":transform_utils", "//tensorflow/cc:cc_ops", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -55,6 +56,7 @@ tf_cc_test( cc_library( name = "transforms_lib", srcs = [ + "add_default_attributes.cc", "fold_batch_norms.cc", "fold_constants_lib.cc", "fold_old_batch_norms.cc", @@ -68,6 +70,7 @@ cc_library( "rename_attribute.cc", "rename_op.cc", "round_weights.cc", + "set_device.cc", "sort_by_execution_order.cc", "strip_unused_nodes.cc", ], @@ -80,6 +83,7 @@ cc_library( ":transform_utils", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -93,6 +97,7 @@ tf_cc_test( name = "transforms_test", size = "small", srcs = [ + "add_default_attributes_test.cc", "fold_batch_norms_test.cc", "fold_constants_test.cc", "fold_old_batch_norms_test.cc", @@ -106,6 +111,7 @@ tf_cc_test( "rename_attribute_test.cc", "rename_op_test.cc", "round_weights_test.cc", + "set_device_test.cc", "sort_by_execution_order_test.cc", "strip_unused_nodes_test.cc", ], @@ -209,3 +215,18 @@ cc_binary( ":summarize_graph_main_lib", ], ) + +cc_binary( + name = "compare_graphs", + srcs = ["compare_graphs.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":transform_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index 
f6087727888..334c65117ee 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -12,6 +12,7 @@ * [Shrinking File Size](#shrinking-file-size) * [Eight-bit Calculations](#eight-bit-calculations) * [Transform Reference](#transform-reference) + * [add_default_attributes](#add_default_attributes) * [fold_batch_norms](#fold_batch_norms) * [fold_constants](#fold_constants) * [fold_old_batch_norms](#fold_old_batch_norms) @@ -26,6 +27,7 @@ * [rename_attribute](#rename_attribute) * [rename_op](#rename_op) * [round_weights](#round_weights) + * [set_device](#set_device) * [sort_by_execution_order](#sort_by_execution_order) * [strip_unused_nodes](#strip_unused_nodes) * [Writing Your Own Transforms](#writing-your-own-transforms) @@ -318,6 +320,18 @@ logged and the transform skipped. This is especially useful for optional transforms where version errors or other unimportant problems may trigger an error. +### add_default_attributes + +Args: None + +When attributes are added to ops in new versions of TensorFlow, they often have +defaults to ensure backwards compatible behavior with their original versions. +These defaults usually get added when the graph is loaded by the runtime, but if +your model is going to be processed outside of the main TensorFlow framework it +can be useful to run this update process as a transform. This process finds any +op attributes that are defined in the current TensorFlow list of ops but not +within the saved model, and sets them to the defined default for that attribute. + ### fold_batch_norms Args: None @@ -516,6 +530,22 @@ between the largest and smallest values present. This is useful when you'll be deploying on mobile, and you want a model that will compress effectively. See [shrinking file size](#shrinking-file-size) for more details. +### set_device + +Args: + +* device: What device to assign to ops. +* if_default: If this is true, only assign to ops with empty existing devices. 
+ +Updates nodes to use the specified device. A device is a way to tell the code +that executes the graph which piece of hardware it should run particular nodes +on. The right assignment to use may change between training and deployment, so +this transform (and [remove_device](#remove_device)) provide a way of updating +the placement. If the `if_default` parameter is set, then only ops that don't +have a device assigned already will be updated. This is mostly useful for +preprocessing of graphs for other stages that expect all ops to have an explicit +device assigned. + +### sort_by_execution_order + +Args: None\ @@ -844,23 +874,13 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps` parameter: ```C++ -string num_steps_string; -TF_RETURN_IF_ERROR( - GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string)); -int32 num_steps; -if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) { - return errors::InvalidArgument( - "Couldn't interpret the num_steps argument to round_weights as a " - "number:", - num_steps_string); -} +TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps)); ``` -Something to notice here is that you have to convert the string to an integer, -and if the conversion fails you need to raise a meaningful error through the -status result of the transform. Also, we're using a helper function which raises -an error if the parameter is present multiple times, and uses a default if the -user hasn't specified it. +If the conversion fails or the parameter occurs more than once the helper +function will raise a meaningful error through the status result of the +transform. If the parameter isn't specified at all then the default will be +used.
### Function Libraries diff --git a/tensorflow/tools/graph_transforms/add_default_attributes.cc b/tensorflow/tools/graph_transforms/add_default_attributes.cc new file mode 100644 index 00000000000..3b20971de1a --- /dev/null +++ b/tensorflow/tools/graph_transforms/add_default_attributes.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Sets any parameters not specified in a node to their defaults. +Status AddDefaultAttributes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + // Find all of the ops that are currently defined. + std::unique_ptr flib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), + input_graph_def.library())); + // Works in-place, so copy over the original graph. 
+ *output_graph_def = input_graph_def; + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(output_graph_def, *flib_def, 0)); + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("add_default_attributes", AddDefaultAttributes); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/add_default_attributes_test.cc b/tensorflow/tools/graph_transforms/add_default_attributes_test.cc new file mode 100644 index 00000000000..a0f1d3162a5 --- /dev/null +++ b/tensorflow/tools/graph_transforms/add_default_attributes_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. 
+Status AddDefaultAttributes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class AddDefaultAttributesTest : public ::testing::Test { + protected: + void TestAddDefaultAttributes() { + GraphDef graph_def; + + NodeDef* lrn_node1 = graph_def.add_node(); + lrn_node1->set_name("lrn_node1"); + lrn_node1->set_op("LRN"); + + NodeDef* lrn_node2 = graph_def.add_node(); + lrn_node2->set_name("lrn_node2"); + lrn_node2->set_op("LRN"); + SetNodeAttr("depth_radius", 7, lrn_node2); + SetNodeAttr("bias", 2.0f, lrn_node2); + SetNodeAttr("alpha", 2.0f, lrn_node2); + SetNodeAttr("beta", 1.0f, lrn_node2); + + GraphDef result; + TF_ASSERT_OK(AddDefaultAttributes(graph_def, {}, &result)); + + std::map nodes; + MapNamesToNodes(result, &nodes); + EXPECT_EQ(5, nodes.at("lrn_node1")->attr().at("depth_radius").i()); + EXPECT_NEAR(1.0f, nodes.at("lrn_node1")->attr().at("bias").f(), 1e-5f); + EXPECT_NEAR(1.0f, nodes.at("lrn_node1")->attr().at("alpha").f(), 1e-5f); + EXPECT_NEAR(0.5f, nodes.at("lrn_node1")->attr().at("beta").f(), 1e-5f); + EXPECT_EQ(7, nodes.at("lrn_node2")->attr().at("depth_radius").i()); + EXPECT_NEAR(2.0f, nodes.at("lrn_node2")->attr().at("bias").f(), 1e-5f); + EXPECT_NEAR(2.0f, nodes.at("lrn_node2")->attr().at("alpha").f(), 1e-5f); + EXPECT_NEAR(1.0f, nodes.at("lrn_node2")->attr().at("beta").f(), 1e-5f); + } +}; + +TEST_F(AddDefaultAttributesTest, TestAddDefaultAttributes) { + TestAddDefaultAttributes(); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/compare_graphs.cc b/tensorflow/tools/graph_transforms/compare_graphs.cc new file mode 100644 index 00000000000..67198da5e92 --- /dev/null +++ b/tensorflow/tools/graph_transforms/compare_graphs.cc @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Compares two TensorFlow graphs to see if their meaning is the same. This is a +// semantic comparison that's intended to show whether the graphs should produce +// the same results, and so ignores details like version numbers or node +// ordering that don't affect the output. To use it, run something like this: +// +// bazel build tensorflow/tools/graph_transforms:compare_graphs +// bazel-bin/tensorflow/tools/graph_transforms/compare_graphs a.pb b.pb +// +// The return value is 0 if the graphs are equal, 1 if they're different, and -1 +// if there was a problem. + +#include "tensorflow/core/graph/equal_graph_def.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { +namespace { + +int ParseFlagsAndCompareGraphs(int argc, char* argv[]) { + // We need to call this to set up global state for TensorFlow. 
+ port::InitMain(argv[0], &argc, &argv); + + if (argc != 3) { + LOG(ERROR) << "compare_graphs expects two file names as arguments"; + return -1; + } + + GraphDef a; + Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a); + if (!a_load_status.ok()) { + LOG(ERROR) << "Loading graph '" << argv[1] << "' failed with " + << a_load_status.error_message(); + return -1; + } + + GraphDef b; + Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b); + if (!b_load_status.ok()) { + LOG(ERROR) << "Loading graph '" << argv[2] << "' failed with " + << b_load_status.error_message(); + return -1; + } + + string diff; + if (EqualGraphDef(a, b, &diff)) { + std::cout << "Graphs are equal." << std::endl; + return 0; + } else { + std::cout << diff << std::endl; + return 1; + } +} + +} // namespace +} // namespace graph_transforms +} // namespace tensorflow + +int main(int argc, char* argv[]) { + return tensorflow::graph_transforms::ParseFlagsAndCompareGraphs(argc, argv); +} diff --git a/tensorflow/tools/graph_transforms/obsfucate_names.cc b/tensorflow/tools/graph_transforms/obsfucate_names.cc index 00eb0d01b02..c665ed947af 100644 --- a/tensorflow/tools/graph_transforms/obsfucate_names.cc +++ b/tensorflow/tools/graph_transforms/obsfucate_names.cc @@ -40,10 +40,6 @@ Status ObsfucateNames(const GraphDef& input_graph_def, required_nodes.insert(output); } - for (const string& required_node : required_nodes) { - LOG(INFO) << "required_node=" << required_node; - } - const string valid_chars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; const int64 chars_size = valid_chars.size(); diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc index 22ed2b669e7..fa089b86efa 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -157,22 +157,8 @@ Status ExtractRangeFromParams(const TransformFuncContext& context, return 
errors::InvalidArgument("You must pass both ", min_name, " and ", max_name, " into quantize_nodes"); } - std::vector min_strings = context.params.at(min_name); - std::vector max_strings = context.params.at(max_name); - if ((min_strings.size() != 1) || (max_strings.size() != 1)) { - return errors::InvalidArgument("You must pass a single ", min_name, - " and single ", max_name, - " value into " - "quantize_nodes"); - } - if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) { - return errors::InvalidArgument("Couldn't decode ", min_name, - " as a number: ", min_strings[0]); - } - if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) { - return errors::InvalidArgument("Couldn't decode ", max_name, - " as a number: ", max_strings[0]); - } + TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value)); + TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value)); return Status::OK(); } diff --git a/tensorflow/tools/graph_transforms/round_weights.cc b/tensorflow/tools/graph_transforms/round_weights.cc index e73aae0f393..6332876077c 100644 --- a/tensorflow/tools/graph_transforms/round_weights.cc +++ b/tensorflow/tools/graph_transforms/round_weights.cc @@ -33,16 +33,8 @@ namespace graph_transforms { Status RoundWeights(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { - string num_steps_string; - TF_RETURN_IF_ERROR( - GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string)); - int32 num_steps; - if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) { - return errors::InvalidArgument( - "Couldn't interpret the num_steps argument to round_weights as a " - "number:", - num_steps_string); - } + int64 num_steps; + TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps)); TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( input_graph_def, {"Const"}, [num_steps](const NodeMatch& match, const std::set& input_nodes, diff --git 
a/tensorflow/tools/graph_transforms/set_device.cc b/tensorflow/tools/graph_transforms/set_device.cc new file mode 100644 index 00000000000..4e4529f4b6d --- /dev/null +++ b/tensorflow/tools/graph_transforms/set_device.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Sets the device field of ops in the graph. 
+Status SetDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + string new_device; + TF_RETURN_IF_ERROR(context.GetOneStringParameter("device", "", &new_device)); + bool if_default; + TF_RETURN_IF_ERROR( + context.GetOneBoolParameter("if_default", false, &if_default)); + + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + if (!if_default || (node.device() == "")) { + new_node->set_device(new_device); + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("set_device", SetDevice); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/set_device_test.cc b/tensorflow/tools/graph_transforms/set_device_test.cc new file mode 100644 index 00000000000..fb64e0019d3 --- /dev/null +++ b/tensorflow/tools/graph_transforms/set_device_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status SetDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +namespace { +GraphDef CreateDeviceGraph() { + GraphDef graph_def; + + NodeDef* mul_node1 = graph_def.add_node(); + mul_node1->set_name("mul_node1"); + mul_node1->set_op("Mul"); + mul_node1->set_device("/device:CPU:0"); + mul_node1->add_input("add_node2"); + mul_node1->add_input("add_node3"); + + NodeDef* add_node2 = graph_def.add_node(); + add_node2->set_name("add_node2"); + add_node2->set_op("Add"); + add_node2->add_input("const_node1"); + add_node2->add_input("const_node2"); + add_node2->set_device("/device:GPU:1"); + + NodeDef* add_node3 = graph_def.add_node(); + add_node3->set_name("add_node3"); + add_node3->set_op("Add"); + add_node3->add_input("const_node1"); + add_node3->add_input("const_node3"); + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* add_node4 = graph_def.add_node(); + 
add_node4->set_name("add_node4"); + add_node4->set_op("Add"); + add_node4->add_input("add_node2"); + add_node4->add_input("add_node3"); + + return graph_def; +} +} // namespace + +TEST(SetDeviceTest, TestSetDevice) { + GraphDef graph_def = CreateDeviceGraph(); + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + context.params.insert(std::pair>( + {"device", {string("/device:CPU:0")}})); + TF_ASSERT_OK(SetDevice(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device()); + EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node2")->device()); + EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node3")->device()); + EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node1")->device()); + EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node2")->device()); + EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node3")->device()); + EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node4")->device()); +} + +TEST(SetDeviceTest, TestSetDeviceIfDefault) { + GraphDef graph_def = CreateDeviceGraph(); + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"mul_node1"}; + context.params.insert(std::pair>( + {"device", {string("/device:GPU:0")}})); + context.params.insert( + std::pair>({"if_default", {string("true")}})); + TF_ASSERT_OK(SetDevice(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device()); + EXPECT_EQ("/device:GPU:1", node_lookup.at("add_node2")->device()); + EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node3")->device()); + EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node1")->device()); + EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node2")->device()); + EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node3")->device()); + 
EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node4")->device()); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc index 5e71b0bd5cd..c48be92bb99 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.cc +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -129,12 +129,15 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) { string inputs_string = ""; string outputs_string = ""; string transforms_string = ""; + bool output_as_text = false; std::vector flag_list = { Flag("in_graph", &in_graph, "input graph file name"), Flag("out_graph", &out_graph, "output graph file name"), Flag("inputs", &inputs_string, "inputs"), Flag("outputs", &outputs_string, "outputs"), Flag("transforms", &transforms_string, "list of transforms"), + Flag("output_as_text", &output_as_text, + "whether to write the graph in text protobuf format"), }; string usage = Flags::Usage(argv[0], flag_list); usage += "\nTransforms are:\n"; @@ -185,7 +188,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) { } GraphDef graph_def; - Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def); + Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def); if (!load_status.ok()) { LOG(ERROR) << "Loading graph '" << in_graph << "' failed with " << load_status.error_message(); @@ -202,7 +205,12 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) { return -1; } - Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def); + Status save_status; + if (output_as_text) { + save_status = WriteTextProto(Env::Default(), out_graph, graph_def); + } else { + save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def); + } if (!save_status.ok()) { LOG(ERROR) << "Saving graph '" << out_graph << "' failed with " << save_status.error_message(); diff --git 
a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 72bd7f03836..ecdae72c11d 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -583,23 +583,45 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, return Status::OK(); } -int CountParameters(const TransformFuncContext& context, const string& name) { - if (context.params.count(name)) { - return context.params.at(name).size(); +Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) { + string file_data; + Status load_file_status = + ReadFileToString(Env::Default(), file_name, &file_data); + if (!load_file_status.ok()) { + errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")"); + return load_file_status; + } + // Try to load in binary format first, and then try ascii if that fails. + Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def); + if (!load_status.ok()) { + if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) { + load_status = Status::OK(); + } else { + errors::AppendToMessage(&load_status, + " (both text and binary parsing failed for file ", + file_name, ")"); + } + } + return load_status; +} + +int TransformFuncContext::CountParameters(const string& name) const { + if (params.count(name)) { + return params.at(name).size(); } else { return 0; } } -Status GetExactlyOneParameter(const TransformFuncContext& context, - const string& name, const string& default_value, - string* result) { - const int params_count = CountParameters(context, name); +Status TransformFuncContext::GetOneStringParameter(const string& name, + const string& default_value, + string* result) const { + const int params_count = CountParameters(name); if (params_count == 0) { *result = default_value; return Status::OK(); } else if (params_count == 1) { - *result = context.params.at(name).at(0); + *result = 
params.at(name).at(0); return Status::OK(); } else { return errors::InvalidArgument("Expected a single '", name, @@ -608,5 +630,62 @@ Status GetExactlyOneParameter(const TransformFuncContext& context, } } +Status TransformFuncContext::GetOneIntParameter(const string& name, + int64 default_value, + int64* result) const { + const int params_count = CountParameters(name); + if (params_count == 0) { + *result = default_value; + return Status::OK(); + } + string string_value; + TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); + if (!strings::safe_strto64(StringPiece(string_value), result)) { + return errors::InvalidArgument("Couldn't interpret the ", name, + " argument as a number:", string_value); + } + return Status::OK(); +} + +Status TransformFuncContext::GetOneFloatParameter(const string& name, + float default_value, + float* result) const { + const int params_count = CountParameters(name); + if (params_count == 0) { + *result = default_value; + return Status::OK(); + } + string string_value; + TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); + if (!strings::safe_strtof(string_value.c_str(), result)) { + return errors::InvalidArgument( + "Couldn't interpret the ", name, + " argument as a float number:", string_value); + } + return Status::OK(); +} + +Status TransformFuncContext::GetOneBoolParameter(const string& name, + bool default_value, + bool* result) const { + const int params_count = CountParameters(name); + if (params_count == 0) { + *result = default_value; + return Status::OK(); + } + string string_value; + TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value)); + if (string_value == "true" || string_value == "1") { + *result = true; + } else if (string_value == "false" || string_value == "0") { + *result = false; + } else { + return errors::InvalidArgument("Couldn't interpret the ", name, + " argument as a boolean:", string_value, + " (expected true, false, 0 or 1)"); + } + return Status::OK(); +} + } // 
namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index f87d8326ef5..2c98440907c 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -128,6 +128,10 @@ Status IsGraphValid(const GraphDef& graph_def); Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, DataTypeVector* outputs); +// First tries to load the file as a binary protobuf, if that fails tries to +// parse it as a text protobuf, and returns an error if both fail. +Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph); + // This is used to spot particular subgraphs in a larger model. To use it, // create a pattern like: // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); @@ -213,16 +217,33 @@ struct TransformFuncContext { std::vector input_names; std::vector output_names; TransformFuncParameters params; + + // Returns how many occurrences of the given parameter are present. + int CountParameters(const string& name) const; + + // Gets a single instance of a parameter, using a default if it's not present. + Status GetOneStringParameter(const string& name, const string& default_value, + string* result) const; + + // Gets a single occurrence of a parameter as an integer, falling back to a + // default if it isn't present and returning an error if it isn't convertible + // to a number. + Status GetOneIntParameter(const string& name, int64 default_value, + int64* result) const; + + // Gets a single occurrence of a parameter as a floating point number, falling + // back to a default if it isn't present and returning an error if it isn't + // convertible to a number.
+ Status GetOneFloatParameter(const string& name, float default_value, + float* result) const; + + // Gets a single occurrence of a parameter as a boolean, falling back to a + // default if it isn't present and returning an error if it's not one of + // "true", "1", "false", or "0". + Status GetOneBoolParameter(const string& name, bool default_value, + bool* result) const; }; -// Returns how many occurrences of the given parameter are present. -int CountParameters(const TransformFuncContext& context, const string& name); - -// Gets a simple occurrence of a parameter, using a default if it isn't present. -Status GetExactlyOneParameter(const TransformFuncContext& context, - const string& name, const string& default_value, - string* result); - // This is the function API for all graph transformations, taking an input // GraphDef and other arguments, and returning a transformed GraphDef. typedef std::function(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + const string text_file = + io::JoinPath(testing::TmpDir(), "text_graph.pbtxt"); + TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def)); + + const string binary_file = + io::JoinPath(testing::TmpDir(), "binary_graph.pb"); + TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def)); + + const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb"); + TF_ASSERT_OK( + WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto...")); + + GraphDef text_graph_def; + TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def)); + string text_diff; + EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff)) + << text_diff; + + GraphDef binary_graph_def; + TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def)); + string binary_diff; + EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff)) + << binary_diff; + + GraphDef 
no_graph_def; + EXPECT_FALSE( + LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def) + .ok()); + + GraphDef bogus_graph_def; + EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok()); + } }; TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } @@ -1012,8 +1123,22 @@ TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); } TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); } -TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) { - TestGetExactlyOneParameter(); +TEST_F(TransformUtilsTest, TestGetOneStringParameter) { + TestGetOneStringParameter(); +} + +TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); } + +TEST_F(TransformUtilsTest, TestGetOneFloatParameter) { + TestGetOneFloatParameter(); +} + +TEST_F(TransformUtilsTest, TestGetOneBoolParameter) { + TestGetOneBoolParameter(); +} + +TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) { + TestLoadTextOrBinaryGraphFile(); } } // namespace graph_transforms