Modify dependency structure of transform_utils to allow use within Grappler.
PiperOrigin-RevId: 162998272
This commit is contained in:
parent
a976910e83
commit
6a93f10b81
@ -30,6 +30,6 @@ cc_binary(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
|
||||||
"//tensorflow/core/kernels/hexagon:graph_transferer",
|
"//tensorflow/core/kernels/hexagon:graph_transferer",
|
||||||
"//tensorflow/tools/graph_transforms:transform_utils",
|
"//tensorflow/tools/graph_transforms:file_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
@ -5202,6 +5202,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
@ -111,6 +111,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
@ -26,14 +26,12 @@ cc_library(
|
|||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu_base",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:tensorflow",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -44,11 +42,46 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/cc:cc_ops",
|
"//tensorflow/cc:cc_ops",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu_base",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "file_utils",
|
||||||
|
srcs = [
|
||||||
|
"file_utils.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"file_utils.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "file_utils_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["file_utils_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":file_utils",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
@ -154,6 +187,7 @@ cc_library(
|
|||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":file_utils",
|
||||||
":transform_utils",
|
":transform_utils",
|
||||||
":transforms_lib",
|
":transforms_lib",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
@ -215,6 +249,7 @@ cc_library(
|
|||||||
copts = tf_copts(),
|
copts = tf_copts(),
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":file_utils",
|
||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
@ -240,6 +275,7 @@ cc_binary(
|
|||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":file_utils",
|
||||||
":transform_utils",
|
":transform_utils",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
#include "tensorflow/core/util/equal_graph_def.h"
|
#include "tensorflow/core/util/equal_graph_def.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
46
tensorflow/tools/graph_transforms/file_utils.cc
Normal file
46
tensorflow/tools/graph_transforms/file_utils.cc
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
|
||||||
|
string file_data;
|
||||||
|
Status load_file_status =
|
||||||
|
ReadFileToString(Env::Default(), file_name, &file_data);
|
||||||
|
if (!load_file_status.ok()) {
|
||||||
|
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
|
||||||
|
return load_file_status;
|
||||||
|
}
|
||||||
|
// Try to load in binary format first, and then try ascii if that fails.
|
||||||
|
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
|
||||||
|
if (!load_status.ok()) {
|
||||||
|
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
|
||||||
|
load_status = Status::OK();
|
||||||
|
} else {
|
||||||
|
errors::AppendToMessage(&load_status,
|
||||||
|
" (both text and binary parsing failed for file ",
|
||||||
|
file_name, ")");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return load_status;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
32
tensorflow/tools/graph_transforms/file_utils.h
Normal file
32
tensorflow/tools/graph_transforms/file_utils.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// First tries to load the file as a text protobuf, if that fails tries to parse
|
||||||
|
// it as a binary protobuf, and returns an error if both fail.
|
||||||
|
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def);
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
|
83
tensorflow/tools/graph_transforms/file_utils_test.cc
Normal file
83
tensorflow/tools/graph_transforms/file_utils_test.cc
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/nn_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/util/equal_graph_def.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
class FileUtilsTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void TestLoadTextOrBinaryGraphFile() {
|
||||||
|
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
const int width = 10;
|
||||||
|
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
||||||
|
test::FillIota<float>(&a_data, 1.0f);
|
||||||
|
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||||
|
|
||||||
|
const string text_file =
|
||||||
|
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
|
||||||
|
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
|
||||||
|
|
||||||
|
const string binary_file =
|
||||||
|
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
|
||||||
|
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
|
||||||
|
|
||||||
|
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
|
||||||
|
|
||||||
|
GraphDef text_graph_def;
|
||||||
|
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
|
||||||
|
string text_diff;
|
||||||
|
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
|
||||||
|
<< text_diff;
|
||||||
|
|
||||||
|
GraphDef binary_graph_def;
|
||||||
|
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
|
||||||
|
string binary_diff;
|
||||||
|
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
|
||||||
|
<< binary_diff;
|
||||||
|
|
||||||
|
GraphDef no_graph_def;
|
||||||
|
EXPECT_FALSE(
|
||||||
|
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
|
||||||
|
.ok());
|
||||||
|
|
||||||
|
GraphDef bogus_graph_def;
|
||||||
|
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FileUtilsTest, TestLoadTextOrBinaryGraphFile) {
|
||||||
|
TestLoadTextOrBinaryGraphFile();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/file_utils.h"
|
||||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/public/session.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace graph_transforms {
|
namespace graph_transforms {
|
||||||
@ -587,28 +586,6 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
|
|
||||||
string file_data;
|
|
||||||
Status load_file_status =
|
|
||||||
ReadFileToString(Env::Default(), file_name, &file_data);
|
|
||||||
if (!load_file_status.ok()) {
|
|
||||||
errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
|
|
||||||
return load_file_status;
|
|
||||||
}
|
|
||||||
// Try to load in binary format first, and then try ascii if that fails.
|
|
||||||
Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
|
|
||||||
if (!load_status.ok()) {
|
|
||||||
if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
|
|
||||||
load_status = Status::OK();
|
|
||||||
} else {
|
|
||||||
errors::AppendToMessage(&load_status,
|
|
||||||
" (both text and binary parsing failed for file ",
|
|
||||||
file_name, ")");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return load_status;
|
|
||||||
}
|
|
||||||
|
|
||||||
int TransformFuncContext::CountParameters(const string& name) const {
|
int TransformFuncContext::CountParameters(const string& name) const {
|
||||||
if (params.count(name)) {
|
if (params.count(name)) {
|
||||||
return params.at(name).size();
|
return params.at(name).size();
|
||||||
|
@ -109,8 +109,8 @@ void FilterGraphDef(const GraphDef& input_graph_def,
|
|||||||
std::function<bool(const NodeDef&)> selector,
|
std::function<bool(const NodeDef&)> selector,
|
||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
// Creates a copy of the input graph, with all occurrences of the attributes with
|
// Creates a copy of the input graph, with all occurrences of the attributes
|
||||||
// the names in the argument removed from the node defs.
|
// with the names in the argument removed from the node defs.
|
||||||
void RemoveAttributes(const GraphDef& input_graph_def,
|
void RemoveAttributes(const GraphDef& input_graph_def,
|
||||||
const std::vector<string>& attributes,
|
const std::vector<string>& attributes,
|
||||||
GraphDef* output_graph_def);
|
GraphDef* output_graph_def);
|
||||||
@ -133,10 +133,6 @@ Status IsGraphValid(const GraphDef& graph_def);
|
|||||||
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
|
||||||
DataTypeVector* outputs);
|
DataTypeVector* outputs);
|
||||||
|
|
||||||
// First tries to load the file as a text protobuf, if that fails tries to parse
|
|
||||||
// it as a binary protobuf, and returns an error if both fail.
|
|
||||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
|
|
||||||
|
|
||||||
// This is used to spot particular subgraphs in a larger model. To use it,
|
// This is used to spot particular subgraphs in a larger model. To use it,
|
||||||
// create a pattern like:
|
// create a pattern like:
|
||||||
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
|
||||||
|
@ -23,8 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
#include "tensorflow/core/public/session.h"
|
|
||||||
#include "tensorflow/core/util/equal_graph_def.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace graph_transforms {
|
namespace graph_transforms {
|
||||||
@ -1066,50 +1064,6 @@ class TransformUtilsTest : public ::testing::Test {
|
|||||||
TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
|
TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
|
||||||
EXPECT_TRUE(value);
|
EXPECT_TRUE(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestLoadTextOrBinaryGraphFile() {
|
|
||||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
|
||||||
const int width = 10;
|
|
||||||
|
|
||||||
auto root = tensorflow::Scope::NewRootScope();
|
|
||||||
Tensor a_data(DT_FLOAT, TensorShape({width}));
|
|
||||||
test::FillIota<float>(&a_data, 1.0f);
|
|
||||||
Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
|
|
||||||
GraphDef graph_def;
|
|
||||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
|
||||||
|
|
||||||
const string text_file =
|
|
||||||
io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
|
|
||||||
TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
|
|
||||||
|
|
||||||
const string binary_file =
|
|
||||||
io::JoinPath(testing::TmpDir(), "binary_graph.pb");
|
|
||||||
TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
|
|
||||||
|
|
||||||
const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
|
|
||||||
TF_ASSERT_OK(
|
|
||||||
WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
|
|
||||||
|
|
||||||
GraphDef text_graph_def;
|
|
||||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
|
|
||||||
string text_diff;
|
|
||||||
EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
|
|
||||||
<< text_diff;
|
|
||||||
|
|
||||||
GraphDef binary_graph_def;
|
|
||||||
TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
|
|
||||||
string binary_diff;
|
|
||||||
EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
|
|
||||||
<< binary_diff;
|
|
||||||
|
|
||||||
GraphDef no_graph_def;
|
|
||||||
EXPECT_FALSE(
|
|
||||||
LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
|
|
||||||
.ok());
|
|
||||||
|
|
||||||
GraphDef bogus_graph_def;
|
|
||||||
EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
|
||||||
@ -1206,9 +1160,5 @@ TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
|
|||||||
TestGetOneBoolParameter();
|
TestGetOneBoolParameter();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
|
|
||||||
TestLoadTextOrBinaryGraphFile();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user