Modify dependency structure of transform_utils to allow use within Grappler.

PiperOrigin-RevId: 162998272
This commit is contained in:
Suharsh Sivakumar 2017-07-24 15:32:54 -07:00 committed by TensorFlower Gardener
parent a976910e83
commit 6a93f10b81
14 changed files with 210 additions and 85 deletions

View File

@ -30,6 +30,6 @@ cc_binary(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/core/kernels/hexagon:graph_transferer", "//tensorflow/core/kernels/hexagon:graph_transferer",
"//tensorflow/tools/graph_transforms:transform_utils", "//tensorflow/tools/graph_transforms:file_utils",
], ],
) )

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h" #include "tensorflow/tools/graph_transforms/file_utils.h"
namespace tensorflow { namespace tensorflow {

View File

@ -5202,6 +5202,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",

View File

@ -111,6 +111,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",

View File

@ -26,14 +26,12 @@ cc_library(
copts = tf_copts(), copts = tf_copts(),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
], ],
) )
@ -44,11 +42,46 @@ tf_cc_test(
deps = [ deps = [
":transform_utils", ":transform_utils",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "file_utils",
srcs = [
"file_utils.cc",
],
hdrs = [
"file_utils.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)
tf_cc_test(
name = "file_utils_test",
size = "small",
srcs = ["file_utils_test.cc"],
deps = [
":file_utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
@ -154,6 +187,7 @@ cc_library(
copts = tf_copts(), copts = tf_copts(),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":file_utils",
":transform_utils", ":transform_utils",
":transforms_lib", ":transforms_lib",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
@ -215,6 +249,7 @@ cc_library(
copts = tf_copts(), copts = tf_copts(),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":file_utils",
":transform_utils", ":transform_utils",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
@ -240,6 +275,7 @@ cc_binary(
linkstatic = 1, linkstatic = 1,
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":file_utils",
":transform_utils", ":transform_utils",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/util/equal_graph_def.h" #include "tensorflow/core/util/equal_graph_def.h"
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow { namespace tensorflow {

View File

@ -0,0 +1,46 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace graph_transforms {
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
string file_data;
Status load_file_status =
ReadFileToString(Env::Default(), file_name, &file_data);
if (!load_file_status.ok()) {
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
return load_file_status;
}
// Try to load in binary format first, and then try ascii if that fails.
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
if (!load_status.ok()) {
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
load_status = Status::OK();
} else {
errors::AppendToMessage(&load_status,
" (both text and binary parsing failed for file ",
file_name, ")");
}
}
return load_status;
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,32 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace graph_transforms {
// First tries to load the file as a text protobuf, if that fails tries to parse
// it as a binary protobuf, and returns an error if both fail.
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def);
} // namespace graph_transforms
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_

View File

@ -0,0 +1,83 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace graph_transforms {
class FileUtilsTest : public ::testing::Test {
protected:
void TestLoadTextOrBinaryGraphFile() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
auto root = tensorflow::Scope::NewRootScope();
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
const string text_file =
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
const string binary_file =
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
TF_ASSERT_OK(
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
GraphDef text_graph_def;
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
string text_diff;
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
<< text_diff;
GraphDef binary_graph_def;
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
string binary_diff;
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
<< binary_diff;
GraphDef no_graph_def;
EXPECT_FALSE(
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
.ok());
GraphDef bogus_graph_def;
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
}
};
TEST_F(FileUtilsTest, TestLoadTextOrBinaryGraphFile) {
TestLoadTextOrBinaryGraphFile();
}
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow { namespace tensorflow {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow { namespace tensorflow {

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow { namespace tensorflow {
namespace graph_transforms { namespace graph_transforms {
@ -587,28 +586,6 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
return Status::OK(); return Status::OK();
} }
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
string file_data;
Status load_file_status =
ReadFileToString(Env::Default(), file_name, &file_data);
if (!load_file_status.ok()) {
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
return load_file_status;
}
// Try to load in binary format first, and then try ascii if that fails.
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
if (!load_status.ok()) {
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
load_status = Status::OK();
} else {
errors::AppendToMessage(&load_status,
" (both text and binary parsing failed for file ",
file_name, ")");
}
}
return load_status;
}
int TransformFuncContext::CountParameters(const string& name) const { int TransformFuncContext::CountParameters(const string& name) const {
if (params.count(name)) { if (params.count(name)) {
return params.at(name).size(); return params.at(name).size();

View File

@ -109,8 +109,8 @@ void FilterGraphDef(const GraphDef& input_graph_def,
std::function<bool(const NodeDef&)> selector, std::function<bool(const NodeDef&)> selector,
GraphDef* output_graph_def); GraphDef* output_graph_def);
// Creates a copy of the input graph, with all occurrences of the attributes with // Creates a copy of the input graph, with all occurrences of the attributes
// the names in the argument removed from the node defs. // with the names in the argument removed from the node defs.
void RemoveAttributes(const GraphDef& input_graph_def, void RemoveAttributes(const GraphDef& input_graph_def,
const std::vector<string>& attributes, const std::vector<string>& attributes,
GraphDef* output_graph_def); GraphDef* output_graph_def);
@ -133,10 +133,6 @@ Status IsGraphValid(const GraphDef& graph_def);
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
DataTypeVector* outputs); DataTypeVector* outputs);
// First tries to load the file as a text protobuf, if that fails tries to parse
// it as a binary protobuf, and returns an error if both fail.
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
// This is used to spot particular subgraphs in a larger model. To use it, // This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like: // create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});

View File

@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow { namespace tensorflow {
namespace graph_transforms { namespace graph_transforms {
@ -1066,50 +1064,6 @@ class TransformUtilsTest : public ::testing::Test {
TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value)); TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
EXPECT_TRUE(value); EXPECT_TRUE(value);
} }
void TestLoadTextOrBinaryGraphFile() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 10;
auto root = tensorflow::Scope::NewRootScope();
Tensor a_data(DT_FLOAT, TensorShape({width}));
test::FillIota<float>(&a_data, 1.0f);
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
const string text_file =
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
const string binary_file =
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
TF_ASSERT_OK(
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
GraphDef text_graph_def;
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
string text_diff;
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
<< text_diff;
GraphDef binary_graph_def;
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
string binary_diff;
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
<< binary_diff;
GraphDef no_graph_def;
EXPECT_FALSE(
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
.ok());
GraphDef bogus_graph_def;
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
}
}; };
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
@ -1206,9 +1160,5 @@ TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
TestGetOneBoolParameter(); TestGetOneBoolParameter();
} }
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
TestLoadTextOrBinaryGraphFile();
}
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow