Script to create ApiDef files automatically based on OpDef.

PiperOrigin-RevId: 182463327
This commit is contained in:
Anna R 2018-01-18 17:16:06 -08:00 committed by TensorFlower Gardener
parent 9f07ca8df9
commit 547d70e46c
14 changed files with 759 additions and 52 deletions

View File

@ -574,6 +574,7 @@ filegroup(
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/verbs:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/api_def:all_files",
"//tensorflow/core/debug:all_files",
"//tensorflow/core/distributed_runtime:all_files",
"//tensorflow/core/distributed_runtime/rpc:all_files",

View File

@ -421,7 +421,7 @@ tf_cc_test(
tf_gen_op_wrappers_cc(
name = "cc_ops",
api_def_srcs = ["//tensorflow/core:base_api_def"],
api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
op_lib_names = [
"array_ops",
"audio_ops",
@ -526,7 +526,7 @@ cc_library_with_android_deps(
],
copts = tf_copts(),
data = [
"//tensorflow/core:base_api_def",
"//tensorflow/core/api_def:base_api_def",
],
deps = [
"//tensorflow/core:framework",

View File

@ -3456,36 +3456,6 @@ tf_cc_test(
],
)
filegroup(
name = "base_api_def",
srcs = glob(["api_def/base_api/*"]),
)
filegroup(
name = "python_api_def",
srcs = glob(["api_def/python_api/*"]),
)
tf_cc_test(
name = "api_test",
srcs = ["api_def/api_test.cc"],
data = [
":base_api_def",
],
deps = [
":framework",
":framework_internal",
":lib",
":lib_internal",
":lib_test_internal",
":op_gen_lib",
":ops",
":protos_all_cc",
":test",
":test_main",
],
)
tf_cc_test_gpu(
name = "device_tracer_test",
size = "small",

View File

@ -0,0 +1,113 @@
# Description:
# Provides ApiDef access and ApiDef validation for TensorFlow.
#
# The following targets can be used to access ApiDefs:
# :base_api_def
# :python_api_def
package(
default_visibility = ["//visibility:private"],
)
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
filegroup(
name = "base_api_def",
srcs = glob(["base_api/*"]),
visibility = ["//tensorflow:internal"],
)
filegroup(
name = "python_api_def",
srcs = glob(["python_api/*"]),
visibility = ["//tensorflow:internal"],
)
cc_library(
name = "excluded_ops_lib",
srcs = ["excluded_ops.cc"],
hdrs = ["excluded_ops.h"],
)
cc_library(
name = "update_api_def_lib",
srcs = ["update_api_def.cc"],
hdrs = ["update_api_def.h"],
deps = [
":excluded_ops_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "update_api_def_test",
srcs = ["update_api_def_test.cc"],
deps = [
":update_api_def_lib",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_binary(
name = "update_api_def",
srcs = [
"update_api_def_main.cc",
],
data = [
":base_api_def",
],
deps = [
":update_api_def_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "api_test",
srcs = ["api_test.cc"],
data = [
":base_api_def",
],
deps = [
":excluded_ops_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_test_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -21,8 +21,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/api_def/excluded_ops.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
@ -44,15 +44,6 @@ constexpr char kDefaultApiDefDir[] =
constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";
} // namespace
// Returns a list of ops excluded from ApiDef.
// TODO(annarev): figure out if we should keep ApiDefs for these ops as well.
const std::unordered_set<string>* GetExcludedOps() {
static std::unordered_set<string>* excluded_ops =
new std::unordered_set<string>(
{"BigQueryReader", "GenerateBigQueryReaderPartitions"});
return excluded_ops;
}
// Reads golden ApiDef files and returns a map from file name to ApiDef file
// contents.
void GetGoldenApiDefs(Env* env, const string& api_files_dir,

View File

@ -0,0 +1,26 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/api_def/excluded_ops.h"
namespace tensorflow {
const std::unordered_set<std::string>* GetExcludedOps() {
static std::unordered_set<std::string>* excluded_ops =
new std::unordered_set<std::string>(
{"BigQueryReader", "GenerateBigQueryReaderPartitions"});
return excluded_ops;
}
} // namespace tensorflow

View File

@ -0,0 +1,28 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_
#define TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_
#include <string>
#include <unordered_set>
namespace tensorflow {
// Returns a list of ops excluded from ApiDef.
// TODO(annarev): figure out if we should keep ApiDefs for these ops as well
const std::unordered_set<std::string>* GetExcludedOps();
} // namespace tensorflow
#endif // TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_

View File

@ -0,0 +1,272 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/api_def/update_api_def.h"
#include <ctype.h>
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/core/api_def/excluded_ops.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
namespace {
constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt";
// TODO(annarev): look into supporting other prefixes, not just 'doc'.
constexpr char kDocStart[] = ".Doc(R\"doc(";
constexpr char kDocEnd[] = ")doc\")";
// Updates api_def based on the given op.
void FillBaseApiDef(ApiDef* api_def, const OpDef& op) {
api_def->set_graph_op_name(op.name());
// Add arg docs
for (auto& input_arg : op.input_arg()) {
if (!input_arg.description().empty()) {
auto* api_def_in_arg = api_def->add_in_arg();
api_def_in_arg->set_name(input_arg.name());
api_def_in_arg->set_description(input_arg.description());
}
}
for (auto& output_arg : op.output_arg()) {
if (!output_arg.description().empty()) {
auto* api_def_out_arg = api_def->add_out_arg();
api_def_out_arg->set_name(output_arg.name());
api_def_out_arg->set_description(output_arg.description());
}
}
// Add attr docs
for (auto& attr : op.attr()) {
if (!attr.description().empty()) {
auto* api_def_attr = api_def->add_attr();
api_def_attr->set_name(attr.name());
api_def_attr->set_description(attr.description());
}
}
// Add docs
api_def->set_summary(op.summary());
api_def->set_description(op.description());
}
// Returns true if op has any description or summary.
bool OpHasDocs(const OpDef& op) {
if (!op.summary().empty() || !op.description().empty()) {
return true;
}
for (const auto& arg : op.input_arg()) {
if (!arg.description().empty()) {
return true;
}
}
for (const auto& arg : op.output_arg()) {
if (!arg.description().empty()) {
return true;
}
}
for (const auto& attr : op.attr()) {
if (!attr.description().empty()) {
return true;
}
}
return false;
}
// Returns true if summary and all descriptions are the same in op1
// and op2.
bool CheckDocsMatch(const OpDef& op1, const OpDef& op2) {
if (op1.summary() != op2.summary() ||
op1.description() != op2.description() ||
op1.input_arg_size() != op2.input_arg_size() ||
op1.output_arg_size() != op2.output_arg_size() ||
op1.attr_size() != op2.attr_size()) {
return false;
}
// Iterate over args and attrs to compare their docs.
for (int i = 0; i < op1.input_arg_size(); ++i) {
if (op1.input_arg(i).description() != op2.input_arg(i).description()) {
return false;
}
}
for (int i = 0; i < op1.output_arg_size(); ++i) {
if (op1.output_arg(i).description() != op2.output_arg(i).description()) {
return false;
}
}
for (int i = 0; i < op1.attr_size(); ++i) {
if (op1.attr(i).description() != op2.attr(i).description()) {
return false;
}
}
return true;
}
// Returns true if descriptions and summaries in op match a
// given single doc-string.
bool ValidateOpDocs(const OpDef& op, const string& doc) {
OpDefBuilder b(op.name());
// We don't really care about type we use for arguments and
// attributes. We just want to make sure attribute and argument names
// are added so that descriptions can be assigned to them when parsing
// documentation.
for (const auto& arg : op.input_arg()) {
b.Input(arg.name() + ":string");
}
for (const auto& arg : op.output_arg()) {
b.Output(arg.name() + ":string");
}
for (const auto& attr : op.attr()) {
b.Attr(attr.name() + ":string");
}
b.Doc(doc);
OpRegistrationData op_reg_data;
TF_CHECK_OK(b.Finalize(&op_reg_data));
return CheckDocsMatch(op, op_reg_data.op_def);
}
} // namespace
string RemoveDoc(const OpDef& op, const string& file_contents,
size_t start_location) {
// Look for a line starting with .Doc( after the REGISTER_OP.
const auto doc_start_location = file_contents.find(kDocStart, start_location);
const string format_error = strings::Printf(
"Could not find %s doc for removal. Make sure the doc is defined with "
"'%s' prefix and '%s' suffix or remove the doc manually.",
op.name().c_str(), kDocStart, kDocEnd);
if (doc_start_location == string::npos) {
std::cerr << format_error << std::endl;
LOG(ERROR) << "Didn't find doc start";
return file_contents;
}
const auto doc_end_location = file_contents.find(kDocEnd, doc_start_location);
if (doc_end_location == string::npos) {
LOG(ERROR) << "Didn't find doc start";
std::cerr << format_error << std::endl;
return file_contents;
}
const auto doc_start_size = sizeof(kDocStart) - 1;
string doc_text = file_contents.substr(
doc_start_location + doc_start_size,
doc_end_location - doc_start_location - doc_start_size);
// Make sure the doc text we found actually matches OpDef docs to
// avoid removing incorrect text.
if (!ValidateOpDocs(op, doc_text)) {
LOG(ERROR) << "Invalid doc: " << doc_text;
std::cerr << format_error << std::endl;
return file_contents;
}
// Remove .Doc call.
auto before_doc = file_contents.substr(0, doc_start_location);
str_util::StripTrailingWhitespace(&before_doc);
return before_doc +
file_contents.substr(doc_end_location + sizeof(kDocEnd) - 1);
}
namespace {
// Remove .Doc calls that follow REGISTER_OP calls for the given ops.
// We search for REGISTER_OP calls in the given op_files list.
void RemoveDocs(const std::vector<const OpDef*>& ops,
const std::vector<string>& op_files) {
// Set of ops that we already found REGISTER_OP calls for.
std::set<string> processed_ops;
for (const auto& file : op_files) {
string file_contents;
bool file_contents_updated = false;
TF_CHECK_OK(ReadFileToString(Env::Default(), file, &file_contents));
for (auto op : ops) {
if (processed_ops.find(op->name()) != processed_ops.end()) {
// We already found REGISTER_OP call for this op in another file.
continue;
}
string register_call =
strings::Printf("REGISTER_OP(\"%s\")", op->name().c_str());
const auto register_call_location = file_contents.find(register_call);
// Find REGISTER_OP(OpName) call.
if (register_call_location == string::npos) {
continue;
}
std::cout << "Removing .Doc call for " << op->name() << " from " << file
<< "." << std::endl;
file_contents = RemoveDoc(*op, file_contents, register_call_location);
file_contents_updated = true;
processed_ops.insert(op->name());
}
if (file_contents_updated) {
TF_CHECK_OK(WriteStringToFile(Env::Default(), file, file_contents))
<< "Could not remove .Doc calls in " << file
<< ". Make sure the file is writable.";
}
}
}
} // namespace
// Returns ApiDef text representation in multi-line format
// constructed based on the given op.
string CreateApiDef(const OpDef& op) {
ApiDef api_def;
FillBaseApiDef(&api_def, op);
const std::vector<string> multi_line_fields = {"description"};
string new_api_defs_str = api_def.DebugString();
return PBTxtToMultiline(new_api_defs_str, multi_line_fields);
}
// Creates ApiDef files for any new ops.
// If op_file_pattern is not empty, then also removes .Doc calls from
// new op registrations in these files.
void CreateApiDefs(const OpList& ops, const string& api_def_dir,
const string& op_file_pattern) {
auto* excluded_ops = GetExcludedOps();
std::vector<const OpDef*> new_ops_with_docs;
for (const auto& op : ops.op()) {
if (excluded_ops->find(op.name()) != excluded_ops->end()) {
continue;
}
// Form the expected ApiDef path.
string file_path =
io::JoinPath(tensorflow::string(api_def_dir), kApiDefFileFormat);
file_path = strings::Printf(file_path.c_str(), op.name().c_str());
// Create ApiDef if it doesn't exist.
if (!Env::Default()->FileExists(file_path).ok()) {
std::cout << "Creating ApiDef file " << file_path << std::endl;
const auto& api_def_text = CreateApiDef(op);
TF_CHECK_OK(WriteStringToFile(Env::Default(), file_path, api_def_text));
if (OpHasDocs(op)) {
new_ops_with_docs.push_back(&op);
}
}
}
if (!op_file_pattern.empty()) {
std::vector<string> op_files;
TF_CHECK_OK(Env::Default()->GetMatchingPaths(op_file_pattern, &op_files));
RemoveDocs(new_ops_with_docs, op_files);
}
}
} // namespace tensorflow

View File

@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
#define THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
// Functions for updating ApiDef when new ops are added.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Returns ApiDef text representation in multi-line format
// constructed based on the given op.
string CreateApiDef(const OpDef& op);
// Removes .Doc call for the given op.
// If unsuccessful, returns original file_contents and prints an error.
// start_location - We search for .Doc call starting at this location
// in file_contents.
string RemoveDoc(const OpDef& op, const string& file_contents,
size_t start_location);
// Creates api_def_*.pbtxt files for any new ops (i.e. ops that don't have an
// api_def_*.pbtxt file yet).
// If op_file_pattern is non-empty, then this method will also
// look for a REGISTER_OP call for the new ops and removes corresponding
// .Doc() calls since the newly generated api_def_*.pbtxt files will
// store the doc strings.
void CreateApiDefs(const OpList& ops, const string& api_def_dir,
const string& op_file_pattern);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_

View File

@ -14,15 +14,15 @@
# limitations under the License.
# ==============================================================================
# Script to update tensorflow/core/api_def/base_api/api_def*.pbtxt files.
# Script to create tensorflow/core/api_def/base_api/api_def*.pbtxt
# files for new ops.
set -e
current_file="$(readlink -f "$0")"
current_dir="$(dirname "$current_file")"
bazel build //tensorflow/core:api_test
bazel-bin/tensorflow/core/api_test \
--update_api_def \
--api_def_dir="${current_dir}/base_api"
bazel build //tensorflow/core/api_def:update_api_def
bazel-bin/tensorflow/core/api_def/update_api_def \
--api_def_dir="${current_dir}/base_api" \
--op_file_pattern="${current_dir}/../ops/*_ops.cc"

View File

@ -0,0 +1,56 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This program can be used to automatically create an api_def_*.pbtxt
// file based on op definition.
//
// To run, use the following script:
// tensorflow/core/api_def/update_api_def.sh
//
// There are 2 ways to use this script:
// 1. Define a REGISTER_OP call without a .Doc() call. Then, run
// this script and add summaries and descriptions in the generated
// api_def_*.pbtxt file manually.
// 2. Add .Doc() call to a REGISTER_OP call. Then run this script
// to remove that .Doc() call and instead add corresponding summaries
// and descriptions in api_def_*.pbtxt file automatically.
// Note that .Doc() call must have the following format for this to work:
// .Doc(R"doc(<doc goes here>)doc").
#include "tensorflow/core/api_def/update_api_def.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
int main(int argc, char** argv) {
tensorflow::string api_files_dir;
tensorflow::string op_file_pattern;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("api_def_dir", &api_files_dir,
"Base directory of api_def*.pbtxt files."),
tensorflow::Flag("op_file_pattern", &op_file_pattern,
"Pattern that matches C++ files containing REGISTER_OP "
"calls. If specified, we will try to remove .Doc() "
"calls for new ops defined in these files.")};
std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parsed_values_ok) {
std::cerr << usage << std::endl;
return 2;
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
tensorflow::OpList ops;
tensorflow::OpRegistry::Global()->Export(false, &ops);
tensorflow::CreateApiDefs(ops, api_files_dir, op_file_pattern);
}

View File

@ -0,0 +1,205 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/api_def/update_api_def.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(UpdateApiDefTest, TestRemoveDocSingleOp) {
const string op_def_text = R"opdef(
REGISTER_OP("Op1")
.Input("a: T")
.Output("output: T")
.Attr("b: type")
.SetShapeFn(shape_inference::UnchangedShape);
)opdef";
const string op_def_text_with_doc = R"opdef(
REGISTER_OP("Op1")
.Input("a: T")
.Output("output: T")
.Attr("b: type")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Summary for Op1.
Description
for Op1.
b : Description for b.
a: Description for a.
output: Description for output.
)doc");
)opdef";
const string op_text = R"(
name: "Op1"
input_arg {
name: "a"
description: "Description for a."
}
output_arg {
name: "output"
description: "Description for output."
}
attr {
name: "b"
description: "Description for b."
}
summary: "Summary for Op1."
description: "Description\nfor Op1."
)";
OpDef op;
protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT
EXPECT_EQ(op_def_text,
RemoveDoc(op, op_def_text_with_doc, 0 /* start_location */));
}
TEST(UpdateApiDefTest, TestRemoveDocMultipleOps) {
const string op_def_text = R"opdef(
REGISTER_OP("Op1")
.Input("a: T")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Op2")
.Input("a: T")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Op3")
.Input("c: T")
.SetShapeFn(shape_inference::UnchangedShape);
)opdef";
const string op_def_text_with_doc = R"opdef(
REGISTER_OP("Op1")
.Input("a: T")
.Doc(R"doc(
Summary for Op1.
)doc")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Op2")
.Input("a: T")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Summary for Op2.
)doc");
REGISTER_OP("Op3")
.Input("c: T")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Summary for Op3.
)doc");
)opdef";
const string op1_text = R"(
name: "Op1"
input_arg {
name: "a"
}
summary: "Summary for Op1."
)";
const string op2_text = R"(
name: "Op2"
input_arg {
name: "a"
}
summary: "Summary for Op2."
)";
const string op3_text = R"(
name: "Op3"
input_arg {
name: "c"
}
summary: "Summary for Op3."
)";
OpDef op1, op2, op3;
protobuf::TextFormat::ParseFromString(op1_text, &op1); // NOLINT
protobuf::TextFormat::ParseFromString(op2_text, &op2); // NOLINT
protobuf::TextFormat::ParseFromString(op3_text, &op3); // NOLINT
string updated_text =
RemoveDoc(op2, op_def_text_with_doc,
op_def_text_with_doc.find("Op2") /* start_location */);
EXPECT_EQ(string::npos, updated_text.find("Summary for Op2"));
EXPECT_NE(string::npos, updated_text.find("Summary for Op1"));
EXPECT_NE(string::npos, updated_text.find("Summary for Op3"));
updated_text = RemoveDoc(op3, updated_text,
updated_text.find("Op3") /* start_location */);
updated_text = RemoveDoc(op1, updated_text,
updated_text.find("Op1") /* start_location */);
EXPECT_EQ(op_def_text, updated_text);
}
TEST(UpdateApiDefTest, TestCreateApiDef) {
const string op_text = R"(
name: "Op1"
input_arg {
name: "a"
description: "Description for a."
}
output_arg {
name: "output"
description: "Description for output."
}
attr {
name: "b"
description: "Description for b."
}
summary: "Summary for Op1."
description: "Description\nfor Op1."
)";
OpDef op;
protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT
const string expected_api_def = R"(graph_op_name: "Op1"
in_arg {
name: "a"
description: <<END
Description for a.
END
}
out_arg {
name: "output"
description: <<END
Description for output.
END
}
attr {
name: "b"
description: <<END
Description for b.
END
}
summary: "Summary for Op1."
description: <<END
Description
for Op1.
END
)";
EXPECT_EQ(expected_api_def, CreateApiDef(op));
}
} // namespace
} // namespace tensorflow

View File

@ -28,7 +28,7 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
require_shape_functions=require_shape_functions,
generated_target_name=name,
api_def_srcs = [
"//tensorflow/core:base_api_def",
"//tensorflow/core:python_api_def",
"//tensorflow/core/api_def:base_api_def",
"//tensorflow/core/api_def:python_api_def",
],
)

View File

@ -18,8 +18,8 @@ py_test(
srcs = ["api_compatibility_test.py"],
data = [
":convert_from_multiline",
"//tensorflow/core:base_api_def",
"//tensorflow/core:python_api_def",
"//tensorflow/core/api_def:base_api_def",
"//tensorflow/core/api_def:python_api_def",
"//tensorflow/python:hidden_ops",
"//tensorflow/tools/api/golden:api_golden",
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",