Modify dependency structure of transform_utils to allow use within Grappler.
PiperOrigin-RevId: 162998272
This commit is contained in:
		
							parent
							
								
									a976910e83
								
							
						
					
					
						commit
						6a93f10b81
					
				@ -30,6 +30,6 @@ cc_binary(
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "//tensorflow/core/kernels:remote_fused_graph_execute_utils",
 | 
			
		||||
        "//tensorflow/core/kernels/hexagon:graph_transferer",
 | 
			
		||||
        "//tensorflow/tools/graph_transforms:transform_utils",
 | 
			
		||||
        "//tensorflow/tools/graph_transforms:file_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/platform/init_main.h"
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
#include "tensorflow/core/util/command_line_flags.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -5202,6 +5202,7 @@ tf_cc_test(
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:core_cpu_internal",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:tensorflow",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core:testlib",
 | 
			
		||||
 | 
			
		||||
@ -111,6 +111,7 @@ tf_cc_test(
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:core_cpu_internal",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:tensorflow",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core:testlib",
 | 
			
		||||
 | 
			
		||||
@ -26,14 +26,12 @@ cc_library(
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:core_cpu_internal",
 | 
			
		||||
        "//tensorflow/core:core_cpu_base",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:lib_internal",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "//tensorflow/core:tensorflow",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -44,11 +42,46 @@ tf_cc_test(
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":transform_utils",
 | 
			
		||||
        "//tensorflow/cc:cc_ops",
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:core_cpu_base",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "//tensorflow/core:tensorflow",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core:testlib",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "file_utils",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "file_utils.cc",
 | 
			
		||||
    ],
 | 
			
		||||
    hdrs = [
 | 
			
		||||
        "file_utils.h",
 | 
			
		||||
    ],
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "//tensorflow/core:tensorflow",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cc_test(
 | 
			
		||||
    name = "file_utils_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["file_utils_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":file_utils",
 | 
			
		||||
        "//tensorflow/cc:cc_ops",
 | 
			
		||||
        "//tensorflow/core:core_cpu",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core:testlib",
 | 
			
		||||
@ -154,6 +187,7 @@ cc_library(
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":file_utils",
 | 
			
		||||
        ":transform_utils",
 | 
			
		||||
        ":transforms_lib",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
@ -215,6 +249,7 @@ cc_library(
 | 
			
		||||
    copts = tf_copts(),
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":file_utils",
 | 
			
		||||
        ":transform_utils",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
@ -240,6 +275,7 @@ cc_binary(
 | 
			
		||||
    linkstatic = 1,
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":file_utils",
 | 
			
		||||
        ":transform_utils",
 | 
			
		||||
        "//tensorflow/core:core_cpu_internal",
 | 
			
		||||
        "//tensorflow/core:framework_internal",
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
#include "tensorflow/core/util/command_line_flags.h"
 | 
			
		||||
#include "tensorflow/core/util/equal_graph_def.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										46
									
								
								tensorflow/tools/graph_transforms/file_utils.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tensorflow/tools/graph_transforms/file_utils.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,46 @@
 | 
			
		||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/public/session.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace graph_transforms {
 | 
			
		||||
 | 
			
		||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
 | 
			
		||||
  string file_data;
 | 
			
		||||
  Status load_file_status =
 | 
			
		||||
      ReadFileToString(Env::Default(), file_name, &file_data);
 | 
			
		||||
  if (!load_file_status.ok()) {
 | 
			
		||||
    errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
 | 
			
		||||
    return load_file_status;
 | 
			
		||||
  }
 | 
			
		||||
  // Try to load in binary format first, and then try ascii if that fails.
 | 
			
		||||
  Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
 | 
			
		||||
  if (!load_status.ok()) {
 | 
			
		||||
    if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
 | 
			
		||||
      load_status = Status::OK();
 | 
			
		||||
    } else {
 | 
			
		||||
      errors::AppendToMessage(&load_status,
 | 
			
		||||
                              " (both text and binary parsing failed for file ",
 | 
			
		||||
                              file_name, ")");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return load_status;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace graph_transforms
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
							
								
								
									
										32
									
								
								tensorflow/tools/graph_transforms/file_utils.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								tensorflow/tools/graph_transforms/file_utils.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,32 @@
 | 
			
		||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
 | 
			
		||||
#define THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/framework/graph.pb.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace graph_transforms {
 | 
			
		||||
 | 
			
		||||
// First tries to load the file as a text protobuf, if that fails tries to parse
 | 
			
		||||
// it as a binary protobuf, and returns an error if both fail.
 | 
			
		||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def);
 | 
			
		||||
 | 
			
		||||
}  // namespace graph_transforms
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // THIRD_PARTY_TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_FILE_UTILS_H_
 | 
			
		||||
							
								
								
									
										83
									
								
								tensorflow/tools/graph_transforms/file_utils_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								tensorflow/tools/graph_transforms/file_utils_test.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,83 @@
 | 
			
		||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
 | 
			
		||||
#include "tensorflow/cc/ops/const_op.h"
 | 
			
		||||
#include "tensorflow/cc/ops/image_ops.h"
 | 
			
		||||
#include "tensorflow/cc/ops/nn_ops.h"
 | 
			
		||||
#include "tensorflow/cc/ops/standard_ops.h"
 | 
			
		||||
#include "tensorflow/core/framework/tensor_testutil.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status_test_util.h"
 | 
			
		||||
#include "tensorflow/core/lib/io/path.h"
 | 
			
		||||
#include "tensorflow/core/platform/test.h"
 | 
			
		||||
#include "tensorflow/core/platform/test_benchmark.h"
 | 
			
		||||
#include "tensorflow/core/util/equal_graph_def.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace graph_transforms {
 | 
			
		||||
 | 
			
		||||
class FileUtilsTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  void TestLoadTextOrBinaryGraphFile() {
 | 
			
		||||
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
 | 
			
		||||
    const int width = 10;
 | 
			
		||||
 | 
			
		||||
    auto root = tensorflow::Scope::NewRootScope();
 | 
			
		||||
    Tensor a_data(DT_FLOAT, TensorShape({width}));
 | 
			
		||||
    test::FillIota<float>(&a_data, 1.0f);
 | 
			
		||||
    Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
 | 
			
		||||
    GraphDef graph_def;
 | 
			
		||||
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
 | 
			
		||||
 | 
			
		||||
    const string text_file =
 | 
			
		||||
        io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
 | 
			
		||||
    TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
 | 
			
		||||
 | 
			
		||||
    const string binary_file =
 | 
			
		||||
        io::JoinPath(testing::TmpDir(), "binary_graph.pb");
 | 
			
		||||
    TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
 | 
			
		||||
 | 
			
		||||
    const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
 | 
			
		||||
    TF_ASSERT_OK(
 | 
			
		||||
        WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
 | 
			
		||||
 | 
			
		||||
    GraphDef text_graph_def;
 | 
			
		||||
    TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
 | 
			
		||||
    string text_diff;
 | 
			
		||||
    EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
 | 
			
		||||
        << text_diff;
 | 
			
		||||
 | 
			
		||||
    GraphDef binary_graph_def;
 | 
			
		||||
    TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
 | 
			
		||||
    string binary_diff;
 | 
			
		||||
    EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
 | 
			
		||||
        << binary_diff;
 | 
			
		||||
 | 
			
		||||
    GraphDef no_graph_def;
 | 
			
		||||
    EXPECT_FALSE(
 | 
			
		||||
        LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
 | 
			
		||||
            .ok());
 | 
			
		||||
 | 
			
		||||
    GraphDef bogus_graph_def;
 | 
			
		||||
    EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(FileUtilsTest, TestLoadTextOrBinaryGraphFile) {
 | 
			
		||||
  TestLoadTextOrBinaryGraphFile();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace graph_transforms
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
@ -32,6 +32,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/platform/init_main.h"
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
#include "tensorflow/core/util/command_line_flags.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
@ -22,6 +22,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/platform/init_main.h"
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
#include "tensorflow/core/util/command_line_flags.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/file_utils.h"
 | 
			
		||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,6 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/framework/op.h"
 | 
			
		||||
#include "tensorflow/core/lib/hash/hash.h"
 | 
			
		||||
#include "tensorflow/core/lib/strings/str_util.h"
 | 
			
		||||
#include "tensorflow/core/public/session.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace graph_transforms {
 | 
			
		||||
@ -587,28 +586,6 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def) {
 | 
			
		||||
  string file_data;
 | 
			
		||||
  Status load_file_status =
 | 
			
		||||
      ReadFileToString(Env::Default(), file_name, &file_data);
 | 
			
		||||
  if (!load_file_status.ok()) {
 | 
			
		||||
    errors::AppendToMessage(&load_file_status, " (for file ", file_name, ")");
 | 
			
		||||
    return load_file_status;
 | 
			
		||||
  }
 | 
			
		||||
  // Try to load in binary format first, and then try ascii if that fails.
 | 
			
		||||
  Status load_status = ReadBinaryProto(Env::Default(), file_name, graph_def);
 | 
			
		||||
  if (!load_status.ok()) {
 | 
			
		||||
    if (protobuf::TextFormat::ParseFromString(file_data, graph_def)) {
 | 
			
		||||
      load_status = Status::OK();
 | 
			
		||||
    } else {
 | 
			
		||||
      errors::AppendToMessage(&load_status,
 | 
			
		||||
                              " (both text and binary parsing failed for file ",
 | 
			
		||||
                              file_name, ")");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return load_status;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int TransformFuncContext::CountParameters(const string& name) const {
 | 
			
		||||
  if (params.count(name)) {
 | 
			
		||||
    return params.at(name).size();
 | 
			
		||||
 | 
			
		||||
@ -109,8 +109,8 @@ void FilterGraphDef(const GraphDef& input_graph_def,
 | 
			
		||||
                    std::function<bool(const NodeDef&)> selector,
 | 
			
		||||
                    GraphDef* output_graph_def);
 | 
			
		||||
 | 
			
		||||
// Creates a copy of the input graph, with all occurrences of the attributes with
 | 
			
		||||
// the names in the argument removed from the node defs.
 | 
			
		||||
// Creates a copy of the input graph, with all occurrences of the attributes
 | 
			
		||||
// with the names in the argument removed from the node defs.
 | 
			
		||||
void RemoveAttributes(const GraphDef& input_graph_def,
 | 
			
		||||
                      const std::vector<string>& attributes,
 | 
			
		||||
                      GraphDef* output_graph_def);
 | 
			
		||||
@ -133,10 +133,6 @@ Status IsGraphValid(const GraphDef& graph_def);
 | 
			
		||||
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
 | 
			
		||||
                     DataTypeVector* outputs);
 | 
			
		||||
 | 
			
		||||
// First tries to load the file as a text protobuf, if that fails tries to parse
 | 
			
		||||
// it as a binary protobuf, and returns an error if both fail.
 | 
			
		||||
Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph);
 | 
			
		||||
 | 
			
		||||
// This is used to spot particular subgraphs in a larger model. To use it,
 | 
			
		||||
// create a pattern like:
 | 
			
		||||
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
 | 
			
		||||
 | 
			
		||||
@ -23,8 +23,6 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/lib/io/path.h"
 | 
			
		||||
#include "tensorflow/core/platform/test.h"
 | 
			
		||||
#include "tensorflow/core/platform/test_benchmark.h"
 | 
			
		||||
#include "tensorflow/core/public/session.h"
 | 
			
		||||
#include "tensorflow/core/util/equal_graph_def.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace graph_transforms {
 | 
			
		||||
@ -1066,50 +1064,6 @@ class TransformUtilsTest : public ::testing::Test {
 | 
			
		||||
    TF_EXPECT_OK(context.GetOneBoolParameter("not_present", true, &value));
 | 
			
		||||
    EXPECT_TRUE(value);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void TestLoadTextOrBinaryGraphFile() {
 | 
			
		||||
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
 | 
			
		||||
    const int width = 10;
 | 
			
		||||
 | 
			
		||||
    auto root = tensorflow::Scope::NewRootScope();
 | 
			
		||||
    Tensor a_data(DT_FLOAT, TensorShape({width}));
 | 
			
		||||
    test::FillIota<float>(&a_data, 1.0f);
 | 
			
		||||
    Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
 | 
			
		||||
    GraphDef graph_def;
 | 
			
		||||
    TF_ASSERT_OK(root.ToGraphDef(&graph_def));
 | 
			
		||||
 | 
			
		||||
    const string text_file =
 | 
			
		||||
        io::JoinPath(testing::TmpDir(), "text_graph.pbtxt");
 | 
			
		||||
    TF_ASSERT_OK(WriteTextProto(Env::Default(), text_file, graph_def));
 | 
			
		||||
 | 
			
		||||
    const string binary_file =
 | 
			
		||||
        io::JoinPath(testing::TmpDir(), "binary_graph.pb");
 | 
			
		||||
    TF_ASSERT_OK(WriteBinaryProto(Env::Default(), binary_file, graph_def));
 | 
			
		||||
 | 
			
		||||
    const string bogus_file = io::JoinPath(testing::TmpDir(), "bogus_graph.pb");
 | 
			
		||||
    TF_ASSERT_OK(
 | 
			
		||||
        WriteStringToFile(Env::Default(), bogus_file, "Not a !{ proto..."));
 | 
			
		||||
 | 
			
		||||
    GraphDef text_graph_def;
 | 
			
		||||
    TF_EXPECT_OK(LoadTextOrBinaryGraphFile(text_file, &text_graph_def));
 | 
			
		||||
    string text_diff;
 | 
			
		||||
    EXPECT_TRUE(EqualGraphDef(text_graph_def, graph_def, &text_diff))
 | 
			
		||||
        << text_diff;
 | 
			
		||||
 | 
			
		||||
    GraphDef binary_graph_def;
 | 
			
		||||
    TF_EXPECT_OK(LoadTextOrBinaryGraphFile(binary_file, &binary_graph_def));
 | 
			
		||||
    string binary_diff;
 | 
			
		||||
    EXPECT_TRUE(EqualGraphDef(binary_graph_def, graph_def, &binary_diff))
 | 
			
		||||
        << binary_diff;
 | 
			
		||||
 | 
			
		||||
    GraphDef no_graph_def;
 | 
			
		||||
    EXPECT_FALSE(
 | 
			
		||||
        LoadTextOrBinaryGraphFile("____non_existent_file_____", &no_graph_def)
 | 
			
		||||
            .ok());
 | 
			
		||||
 | 
			
		||||
    GraphDef bogus_graph_def;
 | 
			
		||||
    EXPECT_FALSE(LoadTextOrBinaryGraphFile(bogus_file, &bogus_graph_def).ok());
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
 | 
			
		||||
@ -1206,9 +1160,5 @@ TEST_F(TransformUtilsTest, TestGetOneBoolParameter) {
 | 
			
		||||
  TestGetOneBoolParameter();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(TransformUtilsTest, TestLoadTextOrBinaryGraphFile) {
 | 
			
		||||
  TestLoadTextOrBinaryGraphFile();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace graph_transforms
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user