Script to create ApiDef files automatically based on OpDef.
PiperOrigin-RevId: 182463327
This commit is contained in:
parent
9f07ca8df9
commit
547d70e46c
tensorflow
@ -574,6 +574,7 @@ filegroup(
|
||||
"//tensorflow/contrib/util:all_files",
|
||||
"//tensorflow/contrib/verbs:all_files",
|
||||
"//tensorflow/core:all_files",
|
||||
"//tensorflow/core/api_def:all_files",
|
||||
"//tensorflow/core/debug:all_files",
|
||||
"//tensorflow/core/distributed_runtime:all_files",
|
||||
"//tensorflow/core/distributed_runtime/rpc:all_files",
|
||||
|
@ -421,7 +421,7 @@ tf_cc_test(
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "cc_ops",
|
||||
api_def_srcs = ["//tensorflow/core:base_api_def"],
|
||||
api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
|
||||
op_lib_names = [
|
||||
"array_ops",
|
||||
"audio_ops",
|
||||
@ -526,7 +526,7 @@ cc_library_with_android_deps(
|
||||
],
|
||||
copts = tf_copts(),
|
||||
data = [
|
||||
"//tensorflow/core:base_api_def",
|
||||
"//tensorflow/core/api_def:base_api_def",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -3456,36 +3456,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "base_api_def",
|
||||
srcs = glob(["api_def/base_api/*"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "python_api_def",
|
||||
srcs = glob(["api_def/python_api/*"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "api_test",
|
||||
srcs = ["api_def/api_test.cc"],
|
||||
data = [
|
||||
":base_api_def",
|
||||
],
|
||||
deps = [
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":lib_test_internal",
|
||||
":op_gen_lib",
|
||||
":ops",
|
||||
":protos_all_cc",
|
||||
":test",
|
||||
":test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test_gpu(
|
||||
name = "device_tracer_test",
|
||||
size = "small",
|
||||
|
113
tensorflow/core/api_def/BUILD
Normal file
113
tensorflow/core/api_def/BUILD
Normal file
@ -0,0 +1,113 @@
|
||||
# Description:
|
||||
# Provides ApiDef access and ApiDef validation for TensorFlow.
|
||||
#
|
||||
# The following targets can be used to access ApiDefs:
|
||||
# :base_api_def
|
||||
# :python_api_def
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "base_api_def",
|
||||
srcs = glob(["base_api/*"]),
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "python_api_def",
|
||||
srcs = glob(["python_api/*"]),
|
||||
visibility = ["//tensorflow:internal"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "excluded_ops_lib",
|
||||
srcs = ["excluded_ops.cc"],
|
||||
hdrs = ["excluded_ops.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "update_api_def_lib",
|
||||
srcs = ["update_api_def.cc"],
|
||||
hdrs = ["update_api_def.h"],
|
||||
deps = [
|
||||
":excluded_ops_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "update_api_def_test",
|
||||
srcs = ["update_api_def_test.cc"],
|
||||
deps = [
|
||||
":update_api_def_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "update_api_def",
|
||||
srcs = [
|
||||
"update_api_def_main.cc",
|
||||
],
|
||||
data = [
|
||||
":base_api_def",
|
||||
],
|
||||
deps = [
|
||||
":update_api_def_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "api_test",
|
||||
srcs = ["api_test.cc"],
|
||||
data = [
|
||||
":base_api_def",
|
||||
],
|
||||
deps = [
|
||||
":excluded_ops_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_test_internal",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
@ -21,8 +21,8 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/api_def/excluded_ops.h"
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
@ -44,15 +44,6 @@ constexpr char kDefaultApiDefDir[] =
|
||||
constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";
|
||||
} // namespace
|
||||
|
||||
// Returns a list of ops excluded from ApiDef.
|
||||
// TODO(annarev): figure out if we should keep ApiDefs for these ops as well.
|
||||
const std::unordered_set<string>* GetExcludedOps() {
|
||||
static std::unordered_set<string>* excluded_ops =
|
||||
new std::unordered_set<string>(
|
||||
{"BigQueryReader", "GenerateBigQueryReaderPartitions"});
|
||||
return excluded_ops;
|
||||
}
|
||||
|
||||
// Reads golden ApiDef files and returns a map from file name to ApiDef file
|
||||
// contents.
|
||||
void GetGoldenApiDefs(Env* env, const string& api_files_dir,
|
||||
|
26
tensorflow/core/api_def/excluded_ops.cc
Normal file
26
tensorflow/core/api_def/excluded_ops.cc
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/api_def/excluded_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::unordered_set<std::string>* GetExcludedOps() {
|
||||
static std::unordered_set<std::string>* excluded_ops =
|
||||
new std::unordered_set<std::string>(
|
||||
{"BigQueryReader", "GenerateBigQueryReaderPartitions"});
|
||||
return excluded_ops;
|
||||
}
|
||||
} // namespace tensorflow
|
28
tensorflow/core/api_def/excluded_ops.h
Normal file
28
tensorflow/core/api_def/excluded_ops.h
Normal file
@ -0,0 +1,28 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_
|
||||
#define TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns a list of ops excluded from ApiDef.
|
||||
// TODO(annarev): figure out if we should keep ApiDefs for these ops as well
|
||||
const std::unordered_set<std::string>* GetExcludedOps();
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_
|
272
tensorflow/core/api_def/update_api_def.cc
Normal file
272
tensorflow/core/api_def/update_api_def.cc
Normal file
@ -0,0 +1,272 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/api_def/update_api_def.h"
|
||||
|
||||
#include <ctype.h>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/api_def/excluded_ops.h"
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt";
|
||||
// TODO(annarev): look into supporting other prefixes, not just 'doc'.
|
||||
constexpr char kDocStart[] = ".Doc(R\"doc(";
|
||||
constexpr char kDocEnd[] = ")doc\")";
|
||||
|
||||
// Updates api_def based on the given op.
|
||||
void FillBaseApiDef(ApiDef* api_def, const OpDef& op) {
|
||||
api_def->set_graph_op_name(op.name());
|
||||
// Add arg docs
|
||||
for (auto& input_arg : op.input_arg()) {
|
||||
if (!input_arg.description().empty()) {
|
||||
auto* api_def_in_arg = api_def->add_in_arg();
|
||||
api_def_in_arg->set_name(input_arg.name());
|
||||
api_def_in_arg->set_description(input_arg.description());
|
||||
}
|
||||
}
|
||||
for (auto& output_arg : op.output_arg()) {
|
||||
if (!output_arg.description().empty()) {
|
||||
auto* api_def_out_arg = api_def->add_out_arg();
|
||||
api_def_out_arg->set_name(output_arg.name());
|
||||
api_def_out_arg->set_description(output_arg.description());
|
||||
}
|
||||
}
|
||||
// Add attr docs
|
||||
for (auto& attr : op.attr()) {
|
||||
if (!attr.description().empty()) {
|
||||
auto* api_def_attr = api_def->add_attr();
|
||||
api_def_attr->set_name(attr.name());
|
||||
api_def_attr->set_description(attr.description());
|
||||
}
|
||||
}
|
||||
// Add docs
|
||||
api_def->set_summary(op.summary());
|
||||
api_def->set_description(op.description());
|
||||
}
|
||||
|
||||
// Returns true if op has any description or summary.
|
||||
bool OpHasDocs(const OpDef& op) {
|
||||
if (!op.summary().empty() || !op.description().empty()) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& arg : op.input_arg()) {
|
||||
if (!arg.description().empty()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto& arg : op.output_arg()) {
|
||||
if (!arg.description().empty()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
for (const auto& attr : op.attr()) {
|
||||
if (!attr.description().empty()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns true if summary and all descriptions are the same in op1
|
||||
// and op2.
|
||||
bool CheckDocsMatch(const OpDef& op1, const OpDef& op2) {
|
||||
if (op1.summary() != op2.summary() ||
|
||||
op1.description() != op2.description() ||
|
||||
op1.input_arg_size() != op2.input_arg_size() ||
|
||||
op1.output_arg_size() != op2.output_arg_size() ||
|
||||
op1.attr_size() != op2.attr_size()) {
|
||||
return false;
|
||||
}
|
||||
// Iterate over args and attrs to compare their docs.
|
||||
for (int i = 0; i < op1.input_arg_size(); ++i) {
|
||||
if (op1.input_arg(i).description() != op2.input_arg(i).description()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < op1.output_arg_size(); ++i) {
|
||||
if (op1.output_arg(i).description() != op2.output_arg(i).description()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < op1.attr_size(); ++i) {
|
||||
if (op1.attr(i).description() != op2.attr(i).description()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns true if descriptions and summaries in op match a
|
||||
// given single doc-string.
|
||||
bool ValidateOpDocs(const OpDef& op, const string& doc) {
|
||||
OpDefBuilder b(op.name());
|
||||
// We don't really care about type we use for arguments and
|
||||
// attributes. We just want to make sure attribute and argument names
|
||||
// are added so that descriptions can be assigned to them when parsing
|
||||
// documentation.
|
||||
for (const auto& arg : op.input_arg()) {
|
||||
b.Input(arg.name() + ":string");
|
||||
}
|
||||
for (const auto& arg : op.output_arg()) {
|
||||
b.Output(arg.name() + ":string");
|
||||
}
|
||||
for (const auto& attr : op.attr()) {
|
||||
b.Attr(attr.name() + ":string");
|
||||
}
|
||||
b.Doc(doc);
|
||||
OpRegistrationData op_reg_data;
|
||||
TF_CHECK_OK(b.Finalize(&op_reg_data));
|
||||
return CheckDocsMatch(op, op_reg_data.op_def);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
string RemoveDoc(const OpDef& op, const string& file_contents,
|
||||
size_t start_location) {
|
||||
// Look for a line starting with .Doc( after the REGISTER_OP.
|
||||
const auto doc_start_location = file_contents.find(kDocStart, start_location);
|
||||
const string format_error = strings::Printf(
|
||||
"Could not find %s doc for removal. Make sure the doc is defined with "
|
||||
"'%s' prefix and '%s' suffix or remove the doc manually.",
|
||||
op.name().c_str(), kDocStart, kDocEnd);
|
||||
if (doc_start_location == string::npos) {
|
||||
std::cerr << format_error << std::endl;
|
||||
LOG(ERROR) << "Didn't find doc start";
|
||||
return file_contents;
|
||||
}
|
||||
const auto doc_end_location = file_contents.find(kDocEnd, doc_start_location);
|
||||
if (doc_end_location == string::npos) {
|
||||
LOG(ERROR) << "Didn't find doc start";
|
||||
std::cerr << format_error << std::endl;
|
||||
return file_contents;
|
||||
}
|
||||
|
||||
const auto doc_start_size = sizeof(kDocStart) - 1;
|
||||
string doc_text = file_contents.substr(
|
||||
doc_start_location + doc_start_size,
|
||||
doc_end_location - doc_start_location - doc_start_size);
|
||||
|
||||
// Make sure the doc text we found actually matches OpDef docs to
|
||||
// avoid removing incorrect text.
|
||||
if (!ValidateOpDocs(op, doc_text)) {
|
||||
LOG(ERROR) << "Invalid doc: " << doc_text;
|
||||
std::cerr << format_error << std::endl;
|
||||
return file_contents;
|
||||
}
|
||||
// Remove .Doc call.
|
||||
auto before_doc = file_contents.substr(0, doc_start_location);
|
||||
str_util::StripTrailingWhitespace(&before_doc);
|
||||
return before_doc +
|
||||
file_contents.substr(doc_end_location + sizeof(kDocEnd) - 1);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Remove .Doc calls that follow REGISTER_OP calls for the given ops.
|
||||
// We search for REGISTER_OP calls in the given op_files list.
|
||||
void RemoveDocs(const std::vector<const OpDef*>& ops,
|
||||
const std::vector<string>& op_files) {
|
||||
// Set of ops that we already found REGISTER_OP calls for.
|
||||
std::set<string> processed_ops;
|
||||
|
||||
for (const auto& file : op_files) {
|
||||
string file_contents;
|
||||
bool file_contents_updated = false;
|
||||
TF_CHECK_OK(ReadFileToString(Env::Default(), file, &file_contents));
|
||||
|
||||
for (auto op : ops) {
|
||||
if (processed_ops.find(op->name()) != processed_ops.end()) {
|
||||
// We already found REGISTER_OP call for this op in another file.
|
||||
continue;
|
||||
}
|
||||
string register_call =
|
||||
strings::Printf("REGISTER_OP(\"%s\")", op->name().c_str());
|
||||
const auto register_call_location = file_contents.find(register_call);
|
||||
// Find REGISTER_OP(OpName) call.
|
||||
if (register_call_location == string::npos) {
|
||||
continue;
|
||||
}
|
||||
std::cout << "Removing .Doc call for " << op->name() << " from " << file
|
||||
<< "." << std::endl;
|
||||
file_contents = RemoveDoc(*op, file_contents, register_call_location);
|
||||
file_contents_updated = true;
|
||||
|
||||
processed_ops.insert(op->name());
|
||||
}
|
||||
if (file_contents_updated) {
|
||||
TF_CHECK_OK(WriteStringToFile(Env::Default(), file, file_contents))
|
||||
<< "Could not remove .Doc calls in " << file
|
||||
<< ". Make sure the file is writable.";
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Returns ApiDef text representation in multi-line format
|
||||
// constructed based on the given op.
|
||||
string CreateApiDef(const OpDef& op) {
|
||||
ApiDef api_def;
|
||||
FillBaseApiDef(&api_def, op);
|
||||
|
||||
const std::vector<string> multi_line_fields = {"description"};
|
||||
string new_api_defs_str = api_def.DebugString();
|
||||
return PBTxtToMultiline(new_api_defs_str, multi_line_fields);
|
||||
}
|
||||
|
||||
// Creates ApiDef files for any new ops.
|
||||
// If op_file_pattern is not empty, then also removes .Doc calls from
|
||||
// new op registrations in these files.
|
||||
void CreateApiDefs(const OpList& ops, const string& api_def_dir,
|
||||
const string& op_file_pattern) {
|
||||
auto* excluded_ops = GetExcludedOps();
|
||||
std::vector<const OpDef*> new_ops_with_docs;
|
||||
|
||||
for (const auto& op : ops.op()) {
|
||||
if (excluded_ops->find(op.name()) != excluded_ops->end()) {
|
||||
continue;
|
||||
}
|
||||
// Form the expected ApiDef path.
|
||||
string file_path =
|
||||
io::JoinPath(tensorflow::string(api_def_dir), kApiDefFileFormat);
|
||||
file_path = strings::Printf(file_path.c_str(), op.name().c_str());
|
||||
|
||||
// Create ApiDef if it doesn't exist.
|
||||
if (!Env::Default()->FileExists(file_path).ok()) {
|
||||
std::cout << "Creating ApiDef file " << file_path << std::endl;
|
||||
const auto& api_def_text = CreateApiDef(op);
|
||||
TF_CHECK_OK(WriteStringToFile(Env::Default(), file_path, api_def_text));
|
||||
|
||||
if (OpHasDocs(op)) {
|
||||
new_ops_with_docs.push_back(&op);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!op_file_pattern.empty()) {
|
||||
std::vector<string> op_files;
|
||||
TF_CHECK_OK(Env::Default()->GetMatchingPaths(op_file_pattern, &op_files));
|
||||
RemoveDocs(new_ops_with_docs, op_files);
|
||||
}
|
||||
}
|
||||
} // namespace tensorflow
|
45
tensorflow/core/api_def/update_api_def.h
Normal file
45
tensorflow/core/api_def/update_api_def.h
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
|
||||
// Functions for updating ApiDef when new ops are added.
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns ApiDef text representation in multi-line format
|
||||
// constructed based on the given op.
|
||||
string CreateApiDef(const OpDef& op);
|
||||
|
||||
// Removes .Doc call for the given op.
|
||||
// If unsuccessful, returns original file_contents and prints an error.
|
||||
// start_location - We search for .Doc call starting at this location
|
||||
// in file_contents.
|
||||
string RemoveDoc(const OpDef& op, const string& file_contents,
|
||||
size_t start_location);
|
||||
|
||||
// Creates api_def_*.pbtxt files for any new ops (i.e. ops that don't have an
|
||||
// api_def_*.pbtxt file yet).
|
||||
// If op_file_pattern is non-empty, then this method will also
|
||||
// look for a REGISTER_OP call for the new ops and removes corresponding
|
||||
// .Doc() calls since the newly generated api_def_*.pbtxt files will
|
||||
// store the doc strings.
|
||||
void CreateApiDefs(const OpList& ops, const string& api_def_dir,
|
||||
const string& op_file_pattern);
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
|
@ -14,15 +14,15 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Script to update tensorflow/core/api_def/base_api/api_def*.pbtxt files.
|
||||
# Script to create tensorflow/core/api_def/base_api/api_def*.pbtxt
|
||||
# files for new ops.
|
||||
|
||||
set -e
|
||||
|
||||
current_file="$(readlink -f "$0")"
|
||||
current_dir="$(dirname "$current_file")"
|
||||
|
||||
bazel build //tensorflow/core:api_test
|
||||
bazel-bin/tensorflow/core/api_test \
|
||||
--update_api_def \
|
||||
--api_def_dir="${current_dir}/base_api"
|
||||
|
||||
bazel build //tensorflow/core/api_def:update_api_def
|
||||
bazel-bin/tensorflow/core/api_def/update_api_def \
|
||||
--api_def_dir="${current_dir}/base_api" \
|
||||
--op_file_pattern="${current_dir}/../ops/*_ops.cc"
|
||||
|
56
tensorflow/core/api_def/update_api_def_main.cc
Normal file
56
tensorflow/core/api_def/update_api_def_main.cc
Normal file
@ -0,0 +1,56 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// This program can be used to automatically create an api_def_*.pbtxt
|
||||
// file based on op definition.
|
||||
//
|
||||
// To run, use the following script:
|
||||
// tensorflow/core/api_def/update_api_def.sh
|
||||
//
|
||||
// There are 2 ways to use this script:
|
||||
// 1. Define a REGISTER_OP call without a .Doc() call. Then, run
|
||||
// this script and add summaries and descriptions in the generated
|
||||
// api_def_*.pbtxt file manually.
|
||||
// 2. Add .Doc() call to a REGISTER_OP call. Then run this script
|
||||
// to remove that .Doc() call and instead add corresponding summaries
|
||||
// and descriptions in api_def_*.pbtxt file automatically.
|
||||
// Note that .Doc() call must have the following format for this to work:
|
||||
// .Doc(R"doc(<doc goes here>)doc").
|
||||
#include "tensorflow/core/api_def/update_api_def.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::string api_files_dir;
|
||||
tensorflow::string op_file_pattern;
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag("api_def_dir", &api_files_dir,
|
||||
"Base directory of api_def*.pbtxt files."),
|
||||
tensorflow::Flag("op_file_pattern", &op_file_pattern,
|
||||
"Pattern that matches C++ files containing REGISTER_OP "
|
||||
"calls. If specified, we will try to remove .Doc() "
|
||||
"calls for new ops defined in these files.")};
|
||||
std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parsed_values_ok) {
|
||||
std::cerr << usage << std::endl;
|
||||
return 2;
|
||||
}
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
tensorflow::OpList ops;
|
||||
tensorflow::OpRegistry::Global()->Export(false, &ops);
|
||||
tensorflow::CreateApiDefs(ops, api_files_dir, op_file_pattern);
|
||||
}
|
205
tensorflow/core/api_def/update_api_def_test.cc
Normal file
205
tensorflow/core/api_def/update_api_def_test.cc
Normal file
@ -0,0 +1,205 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/api_def/update_api_def.h"
|
||||
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UpdateApiDefTest, TestRemoveDocSingleOp) {
|
||||
const string op_def_text = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.Output("output: T")
|
||||
.Attr("b: type")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
)opdef";
|
||||
|
||||
const string op_def_text_with_doc = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.Output("output: T")
|
||||
.Attr("b: type")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Summary for Op1.
|
||||
|
||||
Description
|
||||
for Op1.
|
||||
|
||||
b : Description for b.
|
||||
a: Description for a.
|
||||
output: Description for output.
|
||||
)doc");
|
||||
)opdef";
|
||||
|
||||
const string op_text = R"(
|
||||
name: "Op1"
|
||||
input_arg {
|
||||
name: "a"
|
||||
description: "Description for a."
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Description for output."
|
||||
}
|
||||
attr {
|
||||
name: "b"
|
||||
description: "Description for b."
|
||||
}
|
||||
summary: "Summary for Op1."
|
||||
description: "Description\nfor Op1."
|
||||
)";
|
||||
OpDef op;
|
||||
protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT
|
||||
|
||||
EXPECT_EQ(op_def_text,
|
||||
RemoveDoc(op, op_def_text_with_doc, 0 /* start_location */));
|
||||
}
|
||||
|
||||
TEST(UpdateApiDefTest, TestRemoveDocMultipleOps) {
|
||||
const string op_def_text = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("Op2")
|
||||
.Input("a: T")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("Op3")
|
||||
.Input("c: T")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
)opdef";
|
||||
|
||||
const string op_def_text_with_doc = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.Doc(R"doc(
|
||||
Summary for Op1.
|
||||
)doc")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("Op2")
|
||||
.Input("a: T")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Summary for Op2.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("Op3")
|
||||
.Input("c: T")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Summary for Op3.
|
||||
)doc");
|
||||
)opdef";
|
||||
|
||||
const string op1_text = R"(
|
||||
name: "Op1"
|
||||
input_arg {
|
||||
name: "a"
|
||||
}
|
||||
summary: "Summary for Op1."
|
||||
)";
|
||||
const string op2_text = R"(
|
||||
name: "Op2"
|
||||
input_arg {
|
||||
name: "a"
|
||||
}
|
||||
summary: "Summary for Op2."
|
||||
)";
|
||||
const string op3_text = R"(
|
||||
name: "Op3"
|
||||
input_arg {
|
||||
name: "c"
|
||||
}
|
||||
summary: "Summary for Op3."
|
||||
)";
|
||||
OpDef op1, op2, op3;
|
||||
protobuf::TextFormat::ParseFromString(op1_text, &op1); // NOLINT
|
||||
protobuf::TextFormat::ParseFromString(op2_text, &op2); // NOLINT
|
||||
protobuf::TextFormat::ParseFromString(op3_text, &op3); // NOLINT
|
||||
|
||||
string updated_text =
|
||||
RemoveDoc(op2, op_def_text_with_doc,
|
||||
op_def_text_with_doc.find("Op2") /* start_location */);
|
||||
EXPECT_EQ(string::npos, updated_text.find("Summary for Op2"));
|
||||
EXPECT_NE(string::npos, updated_text.find("Summary for Op1"));
|
||||
EXPECT_NE(string::npos, updated_text.find("Summary for Op3"));
|
||||
|
||||
updated_text = RemoveDoc(op3, updated_text,
|
||||
updated_text.find("Op3") /* start_location */);
|
||||
updated_text = RemoveDoc(op1, updated_text,
|
||||
updated_text.find("Op1") /* start_location */);
|
||||
EXPECT_EQ(op_def_text, updated_text);
|
||||
}
|
||||
|
||||
TEST(UpdateApiDefTest, TestCreateApiDef) {
|
||||
const string op_text = R"(
|
||||
name: "Op1"
|
||||
input_arg {
|
||||
name: "a"
|
||||
description: "Description for a."
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Description for output."
|
||||
}
|
||||
attr {
|
||||
name: "b"
|
||||
description: "Description for b."
|
||||
}
|
||||
summary: "Summary for Op1."
|
||||
description: "Description\nfor Op1."
|
||||
)";
|
||||
OpDef op;
|
||||
protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT
|
||||
|
||||
const string expected_api_def = R"(graph_op_name: "Op1"
|
||||
in_arg {
|
||||
name: "a"
|
||||
description: <<END
|
||||
Description for a.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Description for output.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "b"
|
||||
description: <<END
|
||||
Description for b.
|
||||
END
|
||||
}
|
||||
summary: "Summary for Op1."
|
||||
description: <<END
|
||||
Description
|
||||
for Op1.
|
||||
END
|
||||
)";
|
||||
EXPECT_EQ(expected_api_def, CreateApiDef(op));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -28,7 +28,7 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
|
||||
require_shape_functions=require_shape_functions,
|
||||
generated_target_name=name,
|
||||
api_def_srcs = [
|
||||
"//tensorflow/core:base_api_def",
|
||||
"//tensorflow/core:python_api_def",
|
||||
"//tensorflow/core/api_def:base_api_def",
|
||||
"//tensorflow/core/api_def:python_api_def",
|
||||
],
|
||||
)
|
||||
|
@ -18,8 +18,8 @@ py_test(
|
||||
srcs = ["api_compatibility_test.py"],
|
||||
data = [
|
||||
":convert_from_multiline",
|
||||
"//tensorflow/core:base_api_def",
|
||||
"//tensorflow/core:python_api_def",
|
||||
"//tensorflow/core/api_def:base_api_def",
|
||||
"//tensorflow/core/api_def:python_api_def",
|
||||
"//tensorflow/python:hidden_ops",
|
||||
"//tensorflow/tools/api/golden:api_golden",
|
||||
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
|
||||
|
Loading…
Reference in New Issue
Block a user