diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 5309724580c..3497c84d582 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -30,6 +30,6 @@ cc_binary( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:remote_fused_graph_execute_utils", "//tensorflow/core/kernels/hexagon:graph_transferer", - "//tensorflow/tools/graph_transforms:transform_utils", + "//tensorflow/tools/graph_transforms:file_utils", ], ) diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 8fc3f35f417..60281951dda 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/tools/graph_transforms/transform_utils.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index e676d5a3670..fffcb980db2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5202,6 +5202,7 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/core/kernels/hexagon/BUILD b/tensorflow/core/kernels/hexagon/BUILD index ce83d7c0fcb..a01a4c40dae 100644 --- a/tensorflow/core/kernels/hexagon/BUILD +++ b/tensorflow/core/kernels/hexagon/BUILD @@ -111,6 +111,7 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index ec582739b74..cad0567b9e9 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -26,14 +26,12 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", ], ) @@ -44,11 +42,46 @@ tf_cc_test( deps = [ ":transform_utils", "//tensorflow/cc:cc_ops", - "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "file_utils", + srcs = [ + "file_utils.cc", + ], + hdrs = [ + "file_utils.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +tf_cc_test( + name = "file_utils_test", + size = "small", + srcs = ["file_utils_test.cc"], + deps = [ + ":file_utils", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -154,6 +187,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":file_utils", ":transform_utils", ":transforms_lib", "//tensorflow/core:framework_internal", @@ -215,6 +249,7 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":file_utils", ":transform_utils", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -240,6 +275,7 @@ cc_binary( linkstatic = 1, visibility = ["//visibility:public"], deps = [ + ":file_utils", ":transform_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", diff --git a/tensorflow/tools/graph_transforms/compare_graphs.cc b/tensorflow/tools/graph_transforms/compare_graphs.cc index 8fce16337f7..28a80a885f8 100644 --- a/tensorflow/tools/graph_transforms/compare_graphs.cc +++ b/tensorflow/tools/graph_transforms/compare_graphs.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/equal_graph_def.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { diff --git a/tensorflow/tools/graph_transforms/file_utils.cc b/tensorflow/tools/graph_transforms/file_utils.cc new file mode 100644 index 00000000000..5649c971982 --- /dev/null +++ b/tensorflow/tools/graph_transforms/file_utils.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/file_utils.h" + +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace graph_transforms { + +Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) { + string file_data; + Status load_file_status = + ReadFileToString(Env::Default(), file_name, &file_data); + if (!load_file_status.ok()) { + errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")"); + return load_file_status; + } + // Try to load in binary format first, and then try ascii if that fails. + Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def); + if (!load_status.ok()) { + if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) { + load_status = Status::OK(); + } else { + errors::AppendToMessage(&load_status, + " (both text and binary parsing failed for file ", + file_name, ")"); + } + } + return load_status; +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/file_utils.h b/tensorflow/tools/graph_transforms/file_utils.h new file mode 100644 index 00000000000..4737e95abce --- /dev/null +++ b/tensorflow/tools/graph_transforms/file_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace graph_transforms { + +// First tries to load the file as a text protobuf, if that fails tries to parse +// it as a binary protobuf, and returns an error if both fail. +Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def); + +} // namespace graph_transforms +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_ diff --git a/tensorflow/tools/graph_transforms/file_utils_test.cc b/tensorflow/tools/graph_transforms/file_utils_test.cc new file mode 100644 index 00000000000..8c898ba0f16 --- /dev/null +++ b/tensorflow/tools/graph_transforms/file_utils_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/file_utils.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { +namespace graph_transforms { + +class FileUtilsTest : public ::testing::Test { + protected: + void TestLoadTextOrBinaryGraphFile() { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + const int width = 10; + + auto root = tensorflow::Scope::NewRootScope(); + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota(&a_data, 1.0f); + Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + const string text_file = + io::JoinPath(testing::TmpDir(), "text_graph.pbtxt"); + TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def)); + + const string binary_file = + io::JoinPath(testing::TmpDir(), "binary_graph.pb"); + TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def)); + + const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb"); + TF_ASSERT_OK( + WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto...")); + + GraphDef text_graph_def; + TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def)); + string text_diff; + EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff)) + << text_diff; + + GraphDef binary_graph_def; + TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def)); + string binary_diff; + EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff)) + << binary_diff; + + GraphDef no_graph_def; + EXPECT_FALSE( + LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def) + .ok()); + + GraphDef bogus_graph_def; + EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok()); + } +}; + +TEST_F(FileUtilsTest, TestLoadTextOrBinaryGraphFile) { + TestLoadTextOrBinaryGraphFile(); +} + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/summarize_graph_main.cc b/tensorflow/tools/graph_transforms/summarize_graph_main.cc index d9806e8ed25..6c404c8061e 100644 --- a/tensorflow/tools/graph_transforms/summarize_graph_main.cc +++ b/tensorflow/tools/graph_transforms/summarize_graph_main.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc index e7694104cbd..28387c2b48c 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.cc +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/file_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 0ef517acc5b..bd1e4c90c06 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/public/session.h" namespace tensorflow { namespace graph_transforms { @@ -587,28 +586,6 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, return Status::OK(); } -Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) { - string file_data; - Status load_file_status = - ReadFileToString(Env::Default(), file_name, &file_data); - if (!load_file_status.ok()) { - errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")"); - return load_file_status; - } - // Try to load in binary format first, and then try ascii if that fails. - Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def); - if (!load_status.ok()) { - if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) { - load_status = Status::OK(); - } else { - errors::AppendToMessage(&load_status, - " (both text and binary parsing failed for file ", - file_name, ")"); - } - } - return load_status; -} - int TransformFuncContext::CountParameters(const string& name) const { if (params.count(name)) { return params.at(name).size(); diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 2db0a24267b..c0fb4924123 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -109,8 +109,8 @@ void FilterGraphDef(const GraphDef& input_graph_def, std::function selector, GraphDef* output_graph_def); -// Creates a copy of the input graph, with all occurrences of the attributes with -// the names in the argument removed from the node defs. +// Creates a copy of the input graph, with all occurrences of the attributes +// with the names in the argument removed from the node defs. void RemoveAttributes(const GraphDef& input_graph_def, const std::vector& attributes, GraphDef* output_graph_def); @@ -133,10 +133,6 @@ Status IsGraphValid(const GraphDef& graph_def); Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, DataTypeVector* outputs); -// First tries to load the file as a text protobuf, if that fails tries to parse -// it as a binary protobuf, and returns an error if both fail. -Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph); - // This is used to spot particular subgraphs in a larger model. To use it, // create a pattern like: // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index d068254b35f..b5bc2d75fd2 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -23,8 +23,6 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace graph_transforms { @@ -1066,50 +1064,6 @@ class TransformUtilsTest : public ::testing::Test { TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value)); EXPECT_TRUE(value); } - - void TestLoadTextOrBinaryGraphFile() { - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - const int width = 10; - - auto root = tensorflow::Scope::NewRootScope(); - Tensor a_data(DT_FLOAT, TensorShape({width})); - test::FillIota(&a_data, 1.0f); - Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data)); - GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); - - const string text_file = - io::JoinPath(testing::TmpDir(), "text_graph.pbtxt"); - TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def)); - - const string binary_file = - io::JoinPath(testing::TmpDir(), "binary_graph.pb"); - TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def)); - - const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb"); - TF_ASSERT_OK( - WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto...")); - - GraphDef text_graph_def; - TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def)); - string text_diff; - EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff)) - << text_diff; - - GraphDef binary_graph_def; - TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def)); - string binary_diff; - EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff)) - << binary_diff; - - GraphDef no_graph_def; - EXPECT_FALSE( - LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def) - .ok()); - - GraphDef bogus_graph_def; - EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok()); - } }; TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } @@ -1206,9 +1160,5 @@ TEST_F(TransformUtilsTest, TestGetOneBoolParameter) { TestGetOneBoolParameter(); } -TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) { - TestLoadTextOrBinaryGraphFile(); -} - } // namespace graph_transforms } // namespace tensorflow