Modify dependency structure of transform_utils to allow use within Grappler.
PiperOrigin-RevId: 162998272
This commit is contained in:
parent
a976910e83
commit
6a93f10b81
@ -30,6 +30,6 @@ cc_binary(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
||||
"//tensorflow/core/kernels/hexagon:graph_transferer",
|
||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
||||
"//tensorflow/tools/graph_transforms:file_utils",
|
||||
],
|
||||
)
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -5202,6 +5202,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
|
@ -111,6 +111,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
|
@ -26,14 +26,12 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
@ -44,11 +42,46 @@ tf_cc_test(
|
||||
deps = [
|
||||
":transform_utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "file_utils",
|
||||
srcs = [
|
||||
"file_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"file_utils.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "file_utils_test",
|
||||
size = "small",
|
||||
srcs = ["file_utils_test.cc"],
|
||||
deps = [
|
||||
":file_utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
@ -154,6 +187,7 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":file_utils",
|
||||
":transform_utils",
|
||||
":transforms_lib",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -215,6 +249,7 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":file_utils",
|
||||
":transform_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -240,6 +275,7 @@ cc_binary(
|
||||
linkstatic = 1,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":file_utils",
|
||||
":transform_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework_internal",
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/core/util/equal_graph_def.h"
|
||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
46
tensorflow/tools/graph_transforms/file_utils.cc
Normal file
46
tensorflow/tools/graph_transforms/file_utils.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
|
||||
string file_data;
|
||||
Status load_file_status =
|
||||
ReadFileToString(Env::Default(), file_name, &file_data);
|
||||
if (!load_file_status.ok()) {
|
||||
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
|
||||
return load_file_status;
|
||||
}
|
||||
// Try to load in binary format first, and then try ascii if that fails.
|
||||
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
|
||||
if (!load_status.ok()) {
|
||||
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
|
||||
load_status = Status::OK();
|
||||
} else {
|
||||
errors::AppendToMessage(&load_status,
|
||||
" (both text and binary parsing failed for file ",
|
||||
file_name, ")");
|
||||
}
|
||||
}
|
||||
return load_status;
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
32
tensorflow/tools/graph_transforms/file_utils.h
Normal file
32
tensorflow/tools/graph_transforms/file_utils.h
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// First tries to load the file as a text protobuf, if that fails tries to parse
|
||||
// it as a binary protobuf, and returns an error if both fail.
|
||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
|
83
tensorflow/tools/graph_transforms/file_utils_test.cc
Normal file
83
tensorflow/tools/graph_transforms/file_utils_test.cc
Normal file
@ -0,0 +1,83 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/util/equal_graph_def.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
class FileUtilsTest : public ::testing::Test {
|
||||
protected:
|
||||
void TestLoadTextOrBinaryGraphFile() {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
const int width = 10;
|
||||
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||
test::FillIota<float>(&a_data, 1.0f);
|
||||
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
const string text_file =
|
||||
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
|
||||
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
|
||||
|
||||
const string binary_file =
|
||||
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
|
||||
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
|
||||
|
||||
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
|
||||
TF_ASSERT_OK(
|
||||
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
|
||||
|
||||
GraphDef text_graph_def;
|
||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
|
||||
string text_diff;
|
||||
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
|
||||
<< text_diff;
|
||||
|
||||
GraphDef binary_graph_def;
|
||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
|
||||
string binary_diff;
|
||||
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
|
||||
<< binary_diff;
|
||||
|
||||
GraphDef no_graph_def;
|
||||
EXPECT_FALSE(
|
||||
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
|
||||
.ok());
|
||||
|
||||
GraphDef bogus_graph_def;
|
||||
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FileUtilsTest, TestLoadTextOrBinaryGraphFile) {
|
||||
TestLoadTextOrBinaryGraphFile();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
@ -587,28 +586,6 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
|
||||
string file_data;
|
||||
Status load_file_status =
|
||||
ReadFileToString(Env::Default(), file_name, &file_data);
|
||||
if (!load_file_status.ok()) {
|
||||
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
|
||||
return load_file_status;
|
||||
}
|
||||
// Try to load in binary format first, and then try ascii if that fails.
|
||||
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
|
||||
if (!load_status.ok()) {
|
||||
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
|
||||
load_status = Status::OK();
|
||||
} else {
|
||||
errors::AppendToMessage(&load_status,
|
||||
" (both text and binary parsing failed for file ",
|
||||
file_name, ")");
|
||||
}
|
||||
}
|
||||
return load_status;
|
||||
}
|
||||
|
||||
int TransformFuncContext::CountParameters(const string& name) const {
|
||||
if (params.count(name)) {
|
||||
return params.at(name).size();
|
||||
|
@ -109,8 +109,8 @@ void FilterGraphDef(const GraphDef& input_graph_def,
|
||||
std::function<bool(const NodeDef&)> selector,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
// Creates a copy of the input graph, with all occurrences of the attributes with
|
||||
// the names in the argument removed from the node defs.
|
||||
// Creates a copy of the input graph, with all occurrences of the attributes
|
||||
// with the names in the argument removed from the node defs.
|
||||
void RemoveAttributes(const GraphDef& input_graph_def,
|
||||
const std::vector<string>& attributes,
|
||||
GraphDef* output_graph_def);
|
||||
@ -133,10 +133,6 @@ Status IsGraphValid(const GraphDef& graph_def);
|
||||
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||
DataTypeVector* outputs);
|
||||
|
||||
// First tries to load the file as a text protobuf, if that fails tries to parse
|
||||
// it as a binary protobuf, and returns an error if both fail.
|
||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
|
||||
|
||||
// This is used to spot particular subgraphs in a larger model. To use it,
|
||||
// create a pattern like:
|
||||
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
||||
|
@ -23,8 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/equal_graph_def.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
@ -1066,50 +1064,6 @@ class TransformUtilsTest : public ::testing::Test {
|
||||
TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
|
||||
EXPECT_TRUE(value);
|
||||
}
|
||||
|
||||
void TestLoadTextOrBinaryGraphFile() {
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
const int width = 10;
|
||||
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||
test::FillIota<float>(&a_data, 1.0f);
|
||||
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
|
||||
const string text_file =
|
||||
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
|
||||
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
|
||||
|
||||
const string binary_file =
|
||||
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
|
||||
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
|
||||
|
||||
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
|
||||
TF_ASSERT_OK(
|
||||
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
|
||||
|
||||
GraphDef text_graph_def;
|
||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
|
||||
string text_diff;
|
||||
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
|
||||
<< text_diff;
|
||||
|
||||
GraphDef binary_graph_def;
|
||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
|
||||
string binary_diff;
|
||||
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
|
||||
<< binary_diff;
|
||||
|
||||
GraphDef no_graph_def;
|
||||
EXPECT_FALSE(
|
||||
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
|
||||
.ok());
|
||||
|
||||
GraphDef bogus_graph_def;
|
||||
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
||||
@ -1206,9 +1160,5 @@ TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
|
||||
TestGetOneBoolParameter();
|
||||
}
|
||||
|
||||
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
|
||||
TestLoadTextOrBinaryGraphFile();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user