Add new transforms and compare_graph tool
Change: 144024968
This commit is contained in:
parent
bfa0fa079d
commit
80679a6a74
@ -43,6 +43,7 @@ tf_cc_test(
|
|||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -55,6 +56,7 @@ tf_cc_test(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "transforms_lib",
|
name = "transforms_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"add_default_attributes.cc",
|
||||||
"fold_batch_norms.cc",
|
"fold_batch_norms.cc",
|
||||||
"fold_constants_lib.cc",
|
"fold_constants_lib.cc",
|
||||||
"fold_old_batch_norms.cc",
|
"fold_old_batch_norms.cc",
|
||||||
@ -68,6 +70,7 @@ cc_library(
|
|||||||
"rename_attribute.cc",
|
"rename_attribute.cc",
|
||||||
"rename_op.cc",
|
"rename_op.cc",
|
||||||
"round_weights.cc",
|
"round_weights.cc",
|
||||||
|
"set_device.cc",
|
||||||
"sort_by_execution_order.cc",
|
"sort_by_execution_order.cc",
|
||||||
"strip_unused_nodes.cc",
|
"strip_unused_nodes.cc",
|
||||||
],
|
],
|
||||||
@ -80,6 +83,7 @@ cc_library(
|
|||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
@ -93,6 +97,7 @@ tf_cc_test(
|
|||||||
name = "transforms_test",
|
name = "transforms_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"add_default_attributes_test.cc",
|
||||||
"fold_batch_norms_test.cc",
|
"fold_batch_norms_test.cc",
|
||||||
"fold_constants_test.cc",
|
"fold_constants_test.cc",
|
||||||
"fold_old_batch_norms_test.cc",
|
"fold_old_batch_norms_test.cc",
|
||||||
@ -106,6 +111,7 @@ tf_cc_test(
|
|||||||
"rename_attribute_test.cc",
|
"rename_attribute_test.cc",
|
||||||
"rename_op_test.cc",
|
"rename_op_test.cc",
|
||||||
"round_weights_test.cc",
|
"round_weights_test.cc",
|
||||||
|
"set_device_test.cc",
|
||||||
"sort_by_execution_order_test.cc",
|
"sort_by_execution_order_test.cc",
|
||||||
"strip_unused_nodes_test.cc",
|
"strip_unused_nodes_test.cc",
|
||||||
],
|
],
|
||||||
@ -209,3 +215,18 @@ cc_binary(
|
|||||||
":summarize_graph_main_lib",
|
":summarize_graph_main_lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "compare_graphs",
|
||||||
|
srcs = ["compare_graphs.cc"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
linkstatic = 1,
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":transform_utils",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
* [Shrinking File Size](#shrinking-file-size)
|
* [Shrinking File Size](#shrinking-file-size)
|
||||||
* [Eight-bit Calculations](#eight-bit-calculations)
|
* [Eight-bit Calculations](#eight-bit-calculations)
|
||||||
* [Transform Reference](#transform-reference)
|
* [Transform Reference](#transform-reference)
|
||||||
|
* [add_default_attributes](#add_default_attributes)
|
||||||
* [fold_batch_norms](#fold_batch_norms)
|
* [fold_batch_norms](#fold_batch_norms)
|
||||||
* [fold_constants](#fold_constants)
|
* [fold_constants](#fold_constants)
|
||||||
* [fold_old_batch_norms](#fold_old_batch_norms)
|
* [fold_old_batch_norms](#fold_old_batch_norms)
|
||||||
@ -26,6 +27,7 @@
|
|||||||
* [rename_attribute](#rename_attribute)
|
* [rename_attribute](#rename_attribute)
|
||||||
* [rename_op](#rename_op)
|
* [rename_op](#rename_op)
|
||||||
* [round_weights](#round_weights)
|
* [round_weights](#round_weights)
|
||||||
|
* [set_device](#set_device)
|
||||||
* [sort_by_execution_order](#sort_by_execution_order)
|
* [sort_by_execution_order](#sort_by_execution_order)
|
||||||
* [strip_unused_nodes](#strip_unused_nodes)
|
* [strip_unused_nodes](#strip_unused_nodes)
|
||||||
* [Writing Your Own Transforms](#writing-your-own-transforms)
|
* [Writing Your Own Transforms](#writing-your-own-transforms)
|
||||||
@ -318,6 +320,18 @@ logged and the transform skipped. This is especially useful for optional
|
|||||||
transforms where version errors or other unimportant problems may trigger an
|
transforms where version errors or other unimportant problems may trigger an
|
||||||
error.
|
error.
|
||||||
|
|
||||||
|
### add_default_attributes
|
||||||
|
|
||||||
|
Args: None
|
||||||
|
|
||||||
|
When attributes are added to ops in new versions of TensorFlow, they often have
|
||||||
|
defaults to ensure backwards compatible behavior with their original versions.
|
||||||
|
These defaults usually get added when the graph is loaded by the runtime, but if
|
||||||
|
your model is going to be processed outside of the main TensorFlow framework it
|
||||||
|
can be useful to run this update process as a transform. This process finds any
|
||||||
|
op attributes that are defined in the current TensorFlow list of ops but not
|
||||||
|
within the saved model, and sets them to the defined default for that attribute.
|
||||||
|
|
||||||
### fold_batch_norms
|
### fold_batch_norms
|
||||||
|
|
||||||
Args: None
|
Args: None
|
||||||
@ -516,6 +530,22 @@ between the largest and smallest values present. This is useful when you'll be
|
|||||||
deploying on mobile, and you want a model that will compress effectively. See
|
deploying on mobile, and you want a model that will compress effectively. See
|
||||||
[shrinking file size](#shrinking-file-size) for more details.
|
[shrinking file size](#shrinking-file-size) for more details.
|
||||||
|
|
||||||
|
### set_device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
* device: What device to assign to ops.
|
||||||
|
* if_default: If this is true, only assign to ops with empty existing devices.
|
||||||
|
|
||||||
|
Updates nodes to use the specified device. A device is a way to tell the code
|
||||||
|
that executes the graph which piece of hardware it should run particular nodes
|
||||||
|
on. The right assignment to use may change between training and deployment, so
|
||||||
|
this transform (and [remove_device](#remove_device)) provide a way of updating
|
||||||
|
the placement. If the `is_default` parameter is set, then only ops that don't
|
||||||
|
have a device assigned already will be updated. This is mostly useful for
|
||||||
|
preprocessing of graphs for other stages that expect all ops to have an explicit
|
||||||
|
device assigned.
|
||||||
|
|
||||||
### sort_by_execution_order
|
### sort_by_execution_order
|
||||||
|
|
||||||
Args: None\
|
Args: None\
|
||||||
@ -844,23 +874,13 @@ Here's an example of how [round_weights](#round_weights) reads its `num_steps`
|
|||||||
parameter:
|
parameter:
|
||||||
|
|
||||||
```C++
|
```C++
|
||||||
string num_steps_string;
|
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
|
|
||||||
int32 num_steps;
|
|
||||||
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Couldn't interpret the num_steps argument to round_weights as a "
|
|
||||||
"number:",
|
|
||||||
num_steps_string);
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Something to notice here is that you have to convert the string to an integer,
|
If the conversion fails or the parameter occurs more than once the helper
|
||||||
and if the conversion fails you need to raise a meaningful error through the
|
function will raise a meaningful error through the status result of the
|
||||||
status result of the transform. Also, we're using a helper function which raises
|
transform. If the parameter isn't specified at all then the default will be
|
||||||
an error if the parameter is present multiple times, and uses a default if the
|
used.
|
||||||
user hasn't specified it.
|
|
||||||
|
|
||||||
### Function Libraries
|
### Function Libraries
|
||||||
|
|
||||||
|
40
tensorflow/tools/graph_transforms/add_default_attributes.cc
Normal file
40
tensorflow/tools/graph_transforms/add_default_attributes.cc
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Sets any parameters not specified in a node to their defaults.
|
||||||
|
Status AddDefaultAttributes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
// Find all of the ops that are currently defined.
|
||||||
|
std::unique_ptr<FunctionLibraryDefinition> flib_def(
|
||||||
|
new FunctionLibraryDefinition(OpRegistry::Global(),
|
||||||
|
input_graph_def.library()));
|
||||||
|
// Works in-place, so copy over the original graph.
|
||||||
|
*output_graph_def = input_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(output_graph_def, *flib_def, 0));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("add_default_attributes", AddDefaultAttributes);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,74 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status AddDefaultAttributes(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
class AddDefaultAttributesTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestAddDefaultAttributes() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* lrn_node1 = graph_def.add_node();
|
||||||
|
lrn_node1->set_name("lrn_node1");
|
||||||
|
lrn_node1->set_op("LRN");
|
||||||
|
|
||||||
|
NodeDef* lrn_node2 = graph_def.add_node();
|
||||||
|
lrn_node2->set_name("lrn_node2");
|
||||||
|
lrn_node2->set_op("LRN");
|
||||||
|
SetNodeAttr("depth_radius", 7, lrn_node2);
|
||||||
|
SetNodeAttr("bias", 2.0f, lrn_node2);
|
||||||
|
SetNodeAttr("alpha", 2.0f, lrn_node2);
|
||||||
|
SetNodeAttr("beta", 1.0f, lrn_node2);
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TF_ASSERT_OK(AddDefaultAttributes(graph_def, {}, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> nodes;
|
||||||
|
MapNamesToNodes(result, &nodes);
|
||||||
|
EXPECT_EQ(5, nodes.at("lrn_node1")->attr().at("depth_radius").i());
|
||||||
|
EXPECT_NEAR(1.0f, nodes.at("lrn_node1")->attr().at("bias").f(), 1e-5f);
|
||||||
|
EXPECT_NEAR(1.0f, nodes.at("lrn_node1")->attr().at("alpha").f(), 1e-5f);
|
||||||
|
EXPECT_NEAR(0.5f, nodes.at("lrn_node1")->attr().at("beta").f(), 1e-5f);
|
||||||
|
EXPECT_EQ(7, nodes.at("lrn_node2")->attr().at("depth_radius").i());
|
||||||
|
EXPECT_NEAR(2.0f, nodes.at("lrn_node2")->attr().at("bias").f(), 1e-5f);
|
||||||
|
EXPECT_NEAR(2.0f, nodes.at("lrn_node2")->attr().at("alpha").f(), 1e-5f);
|
||||||
|
EXPECT_NEAR(1.0f, nodes.at("lrn_node2")->attr().at("beta").f(), 1e-5f);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(AddDefaultAttributesTest, TestAddDefaultAttributes) {
|
||||||
|
TestAddDefaultAttributes();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
79
tensorflow/tools/graph_transforms/compare_graphs.cc
Normal file
79
tensorflow/tools/graph_transforms/compare_graphs.cc
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Compares two TensorFlow graphs to see if their meaning is the same. This is a
|
||||||
|
// semantic comparison that's intended to show whether the graphs should produce
|
||||||
|
// the same results, and so ignores details like version numbers or node
|
||||||
|
// ordering that don't affect the output. To use it, run something like this:
|
||||||
|
//
|
||||||
|
// bazel build tensorflow/tools/graph_transforms:compare_graphs
|
||||||
|
// bazel-bin/tensorflow/tools/graph_transforms/compare_graphs a.pb b.pb
|
||||||
|
//
|
||||||
|
// The return value is 0 if the graphs are equal, 1 if they're different, and -1
|
||||||
|
// if there was a problem.
|
||||||
|
|
||||||
|
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
int ParseFlagsAndCompareGraphs(int argc, char* argv[]) {
|
||||||
|
// We need to call this to set up global state for TensorFlow.
|
||||||
|
port::InitMain(argv[0], &argc, &argv);
|
||||||
|
|
||||||
|
if (argc != 3) {
|
||||||
|
LOG(ERROR) << "compare_graphs expects two file names as arguments";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef a;
|
||||||
|
Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a);
|
||||||
|
if (!a_load_status.ok()) {
|
||||||
|
LOG(ERROR) << "Loading graph '" << argv[1] << "' failed with "
|
||||||
|
<< a_load_status.error_message();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDef b;
|
||||||
|
Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b);
|
||||||
|
if (!b_load_status.ok()) {
|
||||||
|
LOG(ERROR) << "Loading graph '" << argv[2] << "' failed with "
|
||||||
|
<< b_load_status.error_message();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
string diff;
|
||||||
|
if (EqualGraphDef(a, b, &diff)) {
|
||||||
|
std::cout << "Graphs are equal." << std::endl;
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
std::cout << diff << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
return tensorflow::graph_transforms::ParseFlagsAndCompareGraphs(argc, argv);
|
||||||
|
}
|
@ -40,10 +40,6 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
|
|||||||
required_nodes.insert(output);
|
required_nodes.insert(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const string& required_node : required_nodes) {
|
|
||||||
LOG(INFO) << "required_node=" << required_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
const string valid_chars =
|
const string valid_chars =
|
||||||
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
||||||
const int64 chars_size = valid_chars.size();
|
const int64 chars_size = valid_chars.size();
|
||||||
|
@ -157,22 +157,8 @@ Status ExtractRangeFromParams(const TransformFuncContext& context,
|
|||||||
return errors::InvalidArgument("You must pass both ", min_name, " and ",
|
return errors::InvalidArgument("You must pass both ", min_name, " and ",
|
||||||
max_name, " into quantize_nodes");
|
max_name, " into quantize_nodes");
|
||||||
}
|
}
|
||||||
std::vector<string> min_strings = context.params.at(min_name);
|
TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
|
||||||
std::vector<string> max_strings = context.params.at(max_name);
|
TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
|
||||||
if ((min_strings.size() != 1) || (max_strings.size() != 1)) {
|
|
||||||
return errors::InvalidArgument("You must pass a single ", min_name,
|
|
||||||
" and single ", max_name,
|
|
||||||
" value into "
|
|
||||||
"quantize_nodes");
|
|
||||||
}
|
|
||||||
if (!strings::safe_strtof(min_strings[0].c_str(), min_value)) {
|
|
||||||
return errors::InvalidArgument("Couldn't decode ", min_name,
|
|
||||||
" as a number: ", min_strings[0]);
|
|
||||||
}
|
|
||||||
if (!strings::safe_strtof(max_strings[0].c_str(), max_value)) {
|
|
||||||
return errors::InvalidArgument("Couldn't decode ", max_name,
|
|
||||||
" as a number: ", max_strings[0]);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,16 +33,8 @@ namespace graph_transforms {
|
|||||||
Status RoundWeights(const GraphDef& input_graph_def,
|
Status RoundWeights(const GraphDef& input_graph_def,
|
||||||
const TransformFuncContext& context,
|
const TransformFuncContext& context,
|
||||||
GraphDef* output_graph_def) {
|
GraphDef* output_graph_def) {
|
||||||
string num_steps_string;
|
int64 num_steps;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(context.GetOneIntParameter("num_steps", 256, &num_steps));
|
||||||
GetExactlyOneParameter(context, "num_steps", "256", &num_steps_string));
|
|
||||||
int32 num_steps;
|
|
||||||
if (!strings::safe_strto32(StringPiece(num_steps_string), &num_steps)) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Couldn't interpret the num_steps argument to round_weights as a "
|
|
||||||
"number:",
|
|
||||||
num_steps_string);
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
input_graph_def, {"Const"},
|
input_graph_def, {"Const"},
|
||||||
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
|
[num_steps](const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
|
46
tensorflow/tools/graph_transforms/set_device.cc
Normal file
46
tensorflow/tools/graph_transforms/set_device.cc
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Sets the device field of ops in the graph.
|
||||||
|
Status SetDevice(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
string new_device;
|
||||||
|
TF_RETURN_IF_ERROR(context.GetOneStringParameter("device", "", &new_device));
|
||||||
|
bool if_default;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context.GetOneBoolParameter("if_default", false, &if_default));
|
||||||
|
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& node : input_graph_def.node()) {
|
||||||
|
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||||
|
new_node->CopyFrom(node);
|
||||||
|
if (!if_default || (node.device() == "")) {
|
||||||
|
new_node->set_device(new_device);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("set_device", SetDevice);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
127
tensorflow/tools/graph_transforms/set_device_test.cc
Normal file
127
tensorflow/tools/graph_transforms/set_device_test.cc
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status SetDevice(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
GraphDef CreateDeviceGraph() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* mul_node1 = graph_def.add_node();
|
||||||
|
mul_node1->set_name("mul_node1");
|
||||||
|
mul_node1->set_op("Mul");
|
||||||
|
mul_node1->set_device("/device:CPU:0");
|
||||||
|
mul_node1->add_input("add_node2");
|
||||||
|
mul_node1->add_input("add_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node2 = graph_def.add_node();
|
||||||
|
add_node2->set_name("add_node2");
|
||||||
|
add_node2->set_op("Add");
|
||||||
|
add_node2->add_input("const_node1");
|
||||||
|
add_node2->add_input("const_node2");
|
||||||
|
add_node2->set_device("/device:GPU:1");
|
||||||
|
|
||||||
|
NodeDef* add_node3 = graph_def.add_node();
|
||||||
|
add_node3->set_name("add_node3");
|
||||||
|
add_node3->set_op("Add");
|
||||||
|
add_node3->add_input("const_node1");
|
||||||
|
add_node3->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* add_node4 = graph_def.add_node();
|
||||||
|
add_node4->set_name("add_node4");
|
||||||
|
add_node4->set_op("Add");
|
||||||
|
add_node4->add_input("add_node2");
|
||||||
|
add_node4->add_input("add_node3");
|
||||||
|
|
||||||
|
return graph_def;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(SetDeviceTest, TestSetDevice) {
|
||||||
|
GraphDef graph_def = CreateDeviceGraph();
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"mul_node1"};
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"device", {string("/device:CPU:0")}}));
|
||||||
|
TF_ASSERT_OK(SetDevice(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node2")->device());
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node3")->device());
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node1")->device());
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node2")->device());
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("const_node3")->device());
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("add_node4")->device());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SetDeviceTest, TestSetDeviceIfDefault) {
|
||||||
|
GraphDef graph_def = CreateDeviceGraph();
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"mul_node1"};
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"device", {string("/device:GPU:0")}}));
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"if_default", {string("true")}}));
|
||||||
|
TF_ASSERT_OK(SetDevice(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ("/device:CPU:0", node_lookup.at("mul_node1")->device());
|
||||||
|
EXPECT_EQ("/device:GPU:1", node_lookup.at("add_node2")->device());
|
||||||
|
EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node3")->device());
|
||||||
|
EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node1")->device());
|
||||||
|
EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node2")->device());
|
||||||
|
EXPECT_EQ("/device:GPU:0", node_lookup.at("const_node3")->device());
|
||||||
|
EXPECT_EQ("/device:GPU:0", node_lookup.at("add_node4")->device());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
@ -129,12 +129,15 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
|||||||
string inputs_string = "";
|
string inputs_string = "";
|
||||||
string outputs_string = "";
|
string outputs_string = "";
|
||||||
string transforms_string = "";
|
string transforms_string = "";
|
||||||
|
bool output_as_text = false;
|
||||||
std::vector<Flag> flag_list = {
|
std::vector<Flag> flag_list = {
|
||||||
Flag("in_graph", &in_graph, "input graph file name"),
|
Flag("in_graph", &in_graph, "input graph file name"),
|
||||||
Flag("out_graph", &out_graph, "output graph file name"),
|
Flag("out_graph", &out_graph, "output graph file name"),
|
||||||
Flag("inputs", &inputs_string, "inputs"),
|
Flag("inputs", &inputs_string, "inputs"),
|
||||||
Flag("outputs", &outputs_string, "outputs"),
|
Flag("outputs", &outputs_string, "outputs"),
|
||||||
Flag("transforms", &transforms_string, "list of transforms"),
|
Flag("transforms", &transforms_string, "list of transforms"),
|
||||||
|
Flag("output_as_text", &output_as_text,
|
||||||
|
"whether to write the graph in text protobuf format"),
|
||||||
};
|
};
|
||||||
string usage = Flags::Usage(argv[0], flag_list);
|
string usage = Flags::Usage(argv[0], flag_list);
|
||||||
usage += "\nTransforms are:\n";
|
usage += "\nTransforms are:\n";
|
||||||
@ -185,7 +188,7 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
Status load_status = ReadBinaryProto(Env::Default(), in_graph, &graph_def);
|
Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
|
||||||
if (!load_status.ok()) {
|
if (!load_status.ok()) {
|
||||||
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
|
LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
|
||||||
<< load_status.error_message();
|
<< load_status.error_message();
|
||||||
@ -202,7 +205,12 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
|
Status save_status;
|
||||||
|
if (output_as_text) {
|
||||||
|
save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
|
||||||
|
} else {
|
||||||
|
save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
|
||||||
|
}
|
||||||
if (!save_status.ok()) {
|
if (!save_status.ok()) {
|
||||||
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
|
LOG(ERROR) << "Saving graph '" << out_graph << "' failed with "
|
||||||
<< save_status.error_message();
|
<< save_status.error_message();
|
||||||
|
@ -583,23 +583,45 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int CountParameters(const TransformFuncContext& context, const string& name) {
|
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
|
||||||
if (context.params.count(name)) {
|
string file_data;
|
||||||
return context.params.at(name).size();
|
Status load_file_status =
|
||||||
|
ReadFileToString(Env::Default(), file_name, &file_data);
|
||||||
|
if (!load_file_status.ok()) {
|
||||||
|
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
|
||||||
|
return load_file_status;
|
||||||
|
}
|
||||||
|
// Try to load in binary format first, and then try ascii if that fails.
|
||||||
|
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
|
||||||
|
if (!load_status.ok()) {
|
||||||
|
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
|
||||||
|
load_status = Status::OK();
|
||||||
|
} else {
|
||||||
|
errors::AppendToMessage(&load_status,
|
||||||
|
" (both text and binary parsing failed for file ",
|
||||||
|
file_name, ")");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return load_status;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TransformFuncContext::CountParameters(const string& name) const {
|
||||||
|
if (params.count(name)) {
|
||||||
|
return params.at(name).size();
|
||||||
} else {
|
} else {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetExactlyOneParameter(const TransformFuncContext& context,
|
Status TransformFuncContext::GetOneStringParameter(const string& name,
|
||||||
const string& name, const string& default_value,
|
const string& default_value,
|
||||||
string* result) {
|
string* result) const {
|
||||||
const int params_count = CountParameters(context, name);
|
const int params_count = CountParameters(name);
|
||||||
if (params_count == 0) {
|
if (params_count == 0) {
|
||||||
*result = default_value;
|
*result = default_value;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else if (params_count == 1) {
|
} else if (params_count == 1) {
|
||||||
*result = context.params.at(name).at(0);
|
*result = params.at(name).at(0);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
return errors::InvalidArgument("Expected a single '", name,
|
return errors::InvalidArgument("Expected a single '", name,
|
||||||
@ -608,5 +630,62 @@ Status GetExactlyOneParameter(const TransformFuncContext& context,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TransformFuncContext::GetOneIntParameter(const string& name,
|
||||||
|
int64 default_value,
|
||||||
|
int64* result) const {
|
||||||
|
const int params_count = CountParameters(name);
|
||||||
|
if (params_count == 0) {
|
||||||
|
*result = default_value;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
string string_value;
|
||||||
|
TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
|
||||||
|
if (!strings::safe_strto64(StringPiece(string_value), result)) {
|
||||||
|
return errors::InvalidArgument("Couldn't interpret the ", name,
|
||||||
|
" argument as a number:", string_value);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TransformFuncContext::GetOneFloatParameter(const string& name,
|
||||||
|
float default_value,
|
||||||
|
float* result) const {
|
||||||
|
const int params_count = CountParameters(name);
|
||||||
|
if (params_count == 0) {
|
||||||
|
*result = default_value;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
string string_value;
|
||||||
|
TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
|
||||||
|
if (!strings::safe_strtof(string_value.c_str(), result)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Couldn't interpret the ", name,
|
||||||
|
" argument as a float number:", string_value);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TransformFuncContext::GetOneBoolParameter(const string& name,
|
||||||
|
bool default_value,
|
||||||
|
bool* result) const {
|
||||||
|
const int params_count = CountParameters(name);
|
||||||
|
if (params_count == 0) {
|
||||||
|
*result = default_value;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
string string_value;
|
||||||
|
TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
|
||||||
|
if (string_value == "true" || string_value == "1") {
|
||||||
|
*result = true;
|
||||||
|
} else if (string_value == "false" || string_value == "0") {
|
||||||
|
*result = false;
|
||||||
|
} else {
|
||||||
|
return errors::InvalidArgument("Couldn't interpret the ", name,
|
||||||
|
" argument as a boolean:", string_value,
|
||||||
|
" (expected true, false, 0 or 1)");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -128,6 +128,10 @@ Status IsGraphValid(const GraphDef& graph_def);
|
|||||||
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||||
DataTypeVector* outputs);
|
DataTypeVector* outputs);
|
||||||
|
|
||||||
|
// First tries to load the file as a text protobuf, if that fails tries to parse
|
||||||
|
// it as a binary protobuf, and returns an error if both fail.
|
||||||
|
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
|
||||||
|
|
||||||
// This is used to spot particular subgraphs in a larger model. To use it,
|
// This is used to spot particular subgraphs in a larger model. To use it,
|
||||||
// create a pattern like:
|
// create a pattern like:
|
||||||
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
||||||
@ -213,15 +217,32 @@ struct TransformFuncContext {
|
|||||||
std::vector<string> input_names;
|
std::vector<string> input_names;
|
||||||
std::vector<string> output_names;
|
std::vector<string> output_names;
|
||||||
TransformFuncParameters params;
|
TransformFuncParameters params;
|
||||||
};
|
|
||||||
|
|
||||||
// Returns how many occurrences of the given parameter are present.
|
// Returns how many occurrences of the given parameter are present.
|
||||||
int CountParameters(const TransformFuncContext& context, const string& name);
|
int CountParameters(const string& name) const;
|
||||||
|
|
||||||
// Gets a simple occurrence of a parameter, using a default if it isn't present.
|
// Gets a single instance of a parameter, using a default if it's not present.
|
||||||
Status GetExactlyOneParameter(const TransformFuncContext& context,
|
Status GetOneStringParameter(const string& name, const string& default_value,
|
||||||
const string& name, const string& default_value,
|
string* result) const;
|
||||||
string* result);
|
|
||||||
|
// Gets a single occurrence of a parameter as an integer, falling back to a
|
||||||
|
// default if it isn't present and returning an error if it isn't convertible
|
||||||
|
// to a number.
|
||||||
|
Status GetOneIntParameter(const string& name, int64 default_value,
|
||||||
|
int64* result) const;
|
||||||
|
|
||||||
|
// Gets a single occurrence of a parameter as a floating point number, falling
|
||||||
|
// back to a default if it isn't present and returning an error if it isn't
|
||||||
|
// convertible to a number.
|
||||||
|
Status GetOneFloatParameter(const string& name, float default_value,
|
||||||
|
float* result) const;
|
||||||
|
|
||||||
|
// Gets a single occurrence of a parameter as a boolean, falling back to a
|
||||||
|
// default if it isn't present and returning an error if it's not one of
|
||||||
|
// "true", "1", "false", or "0".
|
||||||
|
Status GetOneBoolParameter(const string& name, bool default_value,
|
||||||
|
bool* result) const;
|
||||||
|
};
|
||||||
|
|
||||||
// This is the function API for all graph transformations, taking an input
|
// This is the function API for all graph transformations, taking an input
|
||||||
// GraphDef and other arguments, and returning a transformed GraphDef.
|
// GraphDef and other arguments, and returning a transformed GraphDef.
|
||||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/ops/nn_ops.h"
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
#include "tensorflow/core/public/session.h"
|
#include "tensorflow/core/public/session.h"
|
||||||
@ -924,22 +926,131 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
TransformFuncContext context;
|
TransformFuncContext context;
|
||||||
context.params.insert({"foo", {"a", "b"}});
|
context.params.insert({"foo", {"a", "b"}});
|
||||||
context.params.insert({"bar", {"c"}});
|
context.params.insert({"bar", {"c"}});
|
||||||
EXPECT_EQ(2, CountParameters(context, "foo"));
|
EXPECT_EQ(2, context.CountParameters("foo"));
|
||||||
EXPECT_EQ(1, CountParameters(context, "bar"));
|
EXPECT_EQ(1, context.CountParameters("bar"));
|
||||||
EXPECT_EQ(0, CountParameters(context, "not_present"));
|
EXPECT_EQ(0, context.CountParameters("not_present"));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestGetExactlyOneParameter() {
|
void TestGetOneStringParameter() {
|
||||||
TransformFuncContext context;
|
TransformFuncContext context;
|
||||||
context.params.insert({"foo", {"a", "b"}});
|
context.params.insert({"foo", {"a", "b"}});
|
||||||
context.params.insert({"bar", {"c"}});
|
context.params.insert({"bar", {"c"}});
|
||||||
string value;
|
string value;
|
||||||
TF_EXPECT_OK(GetExactlyOneParameter(context, "bar", "d", &value));
|
TF_EXPECT_OK(context.GetOneStringParameter("bar", "d", &value));
|
||||||
EXPECT_EQ("c", value);
|
EXPECT_EQ("c", value);
|
||||||
EXPECT_FALSE(GetExactlyOneParameter(context, "foo", "d", &value).ok());
|
EXPECT_FALSE(context.GetOneStringParameter("foo", "d", &value).ok());
|
||||||
TF_EXPECT_OK(GetExactlyOneParameter(context, "not_present", "d", &value));
|
TF_EXPECT_OK(context.GetOneStringParameter("not_present", "d", &value));
|
||||||
EXPECT_EQ("d", value);
|
EXPECT_EQ("d", value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestGetOneIntParameter() {
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.params.insert({"foo", {"10", "20"}});
|
||||||
|
context.params.insert({"bar", {"-23"}});
|
||||||
|
context.params.insert({"not_a_number", {"not_numerical"}});
|
||||||
|
context.params.insert({"float", {"-23.232323"}});
|
||||||
|
int64 value;
|
||||||
|
TF_EXPECT_OK(context.GetOneIntParameter("bar", 0, &value));
|
||||||
|
EXPECT_EQ(-23, value);
|
||||||
|
EXPECT_FALSE(context.GetOneIntParameter("foo", 0, &value).ok());
|
||||||
|
TF_EXPECT_OK(context.GetOneIntParameter("not_present", 10, &value));
|
||||||
|
EXPECT_EQ(10, value);
|
||||||
|
EXPECT_FALSE(context.GetOneIntParameter("not_a_number", 0, &value).ok());
|
||||||
|
EXPECT_FALSE(context.GetOneIntParameter("float", 0, &value).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestGetOneFloatParameter() {
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.params.insert({"foo", {"10.0", "20.0"}});
|
||||||
|
context.params.insert({"bar", {"-23.2323"}});
|
||||||
|
context.params.insert({"not_a_number", {"not_numerical"}});
|
||||||
|
float value;
|
||||||
|
TF_EXPECT_OK(context.GetOneFloatParameter("bar", 0, &value));
|
||||||
|
EXPECT_NEAR(-23.2323f, value, 1e-5f);
|
||||||
|
EXPECT_FALSE(context.GetOneFloatParameter("foo", 0, &value).ok());
|
||||||
|
TF_EXPECT_OK(context.GetOneFloatParameter("not_present", 10.5f, &value));
|
||||||
|
EXPECT_NEAR(10.5f, value, 1e-5f);
|
||||||
|
EXPECT_FALSE(context.GetOneFloatParameter("not_a_number", 0, &value).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestGetOneBoolParameter() {
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.params.insert({"foo", {"true", "false"}});
|
||||||
|
context.params.insert({"true", {"true"}});
|
||||||
|
context.params.insert({"false", {"false"}});
|
||||||
|
context.params.insert({"one", {"1"}});
|
||||||
|
context.params.insert({"zero", {"0"}});
|
||||||
|
context.params.insert({"not_a_bool", {"not_boolean"}});
|
||||||
|
|
||||||
|
bool value;
|
||||||
|
EXPECT_FALSE(context.GetOneBoolParameter("foo", 0, &value).ok());
|
||||||
|
|
||||||
|
value = false;
|
||||||
|
TF_EXPECT_OK(context.GetOneBoolParameter("true", false, &value));
|
||||||
|
EXPECT_TRUE(value);
|
||||||
|
|
||||||
|
value = true;
|
||||||
|
TF_EXPECT_OK(context.GetOneBoolParameter("false", true, &value));
|
||||||
|
EXPECT_FALSE(value);
|
||||||
|
|
||||||
|
value = false;
|
||||||
|
TF_EXPECT_OK(context.GetOneBoolParameter("one", false, &value));
|
||||||
|
EXPECT_TRUE(value);
|
||||||
|
|
||||||
|
value = true;
|
||||||
|
TF_EXPECT_OK(context.GetOneBoolParameter("zero", true, &value));
|
||||||
|
EXPECT_FALSE(value);
|
||||||
|
|
||||||
|
EXPECT_FALSE(context.GetOneBoolParameter("not_a_bool", false, &value).ok());
|
||||||
|
|
||||||
|
value = false;
|
||||||
|
TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
|
||||||
|
EXPECT_TRUE(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestLoadTextOrBinaryGraphFile() {
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
const string text_file =
|
||||||
|
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
|
||||||
|
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
|
||||||
|
|
||||||
|
const string binary_file =
|
||||||
|
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
|
||||||
|
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
|
||||||
|
|
||||||
|
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
|
||||||
|
|
||||||
|
GraphDef text_graph_def;
|
||||||
|
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
|
||||||
|
string text_diff;
|
||||||
|
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
|
||||||
|
<< text_diff;
|
||||||
|
|
||||||
|
GraphDef binary_graph_def;
|
||||||
|
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
|
||||||
|
string binary_diff;
|
||||||
|
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
|
||||||
|
<< binary_diff;
|
||||||
|
|
||||||
|
GraphDef no_graph_def;
|
||||||
|
EXPECT_FALSE(
|
||||||
|
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
|
||||||
|
.ok());
|
||||||
|
|
||||||
|
GraphDef bogus_graph_def;
|
||||||
|
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
||||||
@ -1012,8 +1123,22 @@ TEST_F(TransformUtilsTest, TestHashNodeDef) { TestHashNodeDef(); }
|
|||||||
|
|
||||||
TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }
|
TEST_F(TransformUtilsTest, TestCountParameters) { TestCountParameters(); }
|
||||||
|
|
||||||
TEST_F(TransformUtilsTest, TestGetExactlyOneParameter) {
|
TEST_F(TransformUtilsTest, TestGetOneStringParameter) {
|
||||||
TestGetExactlyOneParameter();
|
TestGetOneStringParameter();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformUtilsTest, TestGetOneIntParameter) { TestGetOneIntParameter(); }
|
||||||
|
|
||||||
|
TEST_F(TransformUtilsTest, TestGetOneFloatParameter) {
|
||||||
|
TestGetOneFloatParameter();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
|
||||||
|
TestGetOneBoolParameter();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
|
||||||
|
TestLoadTextOrBinaryGraphFile();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
|
Loading…
Reference in New Issue
Block a user