Provide an option to use ApiDef instead of OpGenOverrides when generating the C++ API. Also, update the UpdateDocs method to use ApiDef when replacing names in docs.

PiperOrigin-RevId: 176167953
This commit is contained in:
Anna R 2017-11-17 15:20:49 -08:00 committed by TensorFlower Gardener
parent 3cc43816cd
commit cb12ebe044
28 changed files with 895 additions and 115 deletions

View File

@ -421,6 +421,7 @@ tf_cc_test(
tf_gen_op_wrappers_cc(
name = "cc_ops",
api_def_srcs = ["//tensorflow/core:base_api_def"],
op_lib_names = [
"array_ops",
"audio_ops",
@ -525,6 +526,9 @@ cc_library_with_android_deps(
"//tensorflow/core:android_tensorflow_lib",
],
copts = tf_copts(),
data = [
"//tensorflow/core:base_api_def",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -536,6 +540,29 @@ cc_library_with_android_deps(
],
)
tf_cc_test(
name = "cc_op_gen_test",
srcs = [
"framework/cc_op_gen.cc",
"framework/cc_op_gen.h",
"framework/cc_op_gen_test.cc",
],
data = [
"//tensorflow/cc:ops/op_gen_overrides.pbtxt",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:op_gen_overrides_proto_cc",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "test_op_op_lib",
srcs = ["framework/test_op.cc"],

View File

@ -18,8 +18,10 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
@ -385,10 +387,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) {
}
bool HasOptionalAttrs(
const OpDef& op_def,
const ApiDef& api_def,
const std::unordered_map<string, string>& inferred_input_attrs) {
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) ==
inferred_input_attrs.end()) &&
attr.has_default_value()) {
@ -398,12 +400,21 @@ bool HasOptionalAttrs(
return false;
}
// Looks up the input argument named `name` in `api_def`.
// Returns a pointer to the matching ApiDef::Arg, or nullptr if `api_def`
// has no input argument with that name.
const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
  const int num_in_args = api_def.in_arg_size();
  for (int idx = 0; idx < num_in_args; ++idx) {
    const auto& candidate = api_def.in_arg(idx);
    if (candidate.name() == name) {
      return &candidate;
    }
  }
  return nullptr;
}
struct OpInfo {
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
// interface_op_def: The OpDef used in the interface in the generated
// code, with possibly overridden names and defaults.
explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def,
explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases);
string GetOpAttrStruct() const;
string GetConstructorDecl(StringPiece op_name_prefix,
@ -423,74 +434,81 @@ struct OpInfo {
string comment;
const OpDef& graph_op_def;
const OpDef& op_def;
const ApiDef& api_def;
const std::vector<string>& aliases;
// Map from type attribute to corresponding original argument name.
std::unordered_map<string, string> inferred_input_attrs;
};
OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
const std::vector<string>& a)
: graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
op_name = op_def.name();
InferOpAttributes(op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases)
: graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
op_name = api_def.endpoint(0).name();
InferOpAttributes(graph_op_def, &inferred_input_attrs);
has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");
arg_names.push_back("scope");
if (op_def.has_deprecation()) {
if (!op_def.summary().empty()) {
comment = strings::StrCat(op_def.summary(), "\n");
if (graph_op_def.has_deprecation()) {
if (!api_def.summary().empty()) {
comment = strings::StrCat(api_def.summary(), "\n");
}
strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
op_def.deprecation().version(), ":\n",
op_def.deprecation().explanation(), ".\n");
} else if (op_def.summary().empty()) {
graph_op_def.deprecation().version(), ":\n",
graph_op_def.deprecation().explanation(), ".\n");
} else if (api_def.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
comment = strings::StrCat(op_def.summary(), "\n");
comment = strings::StrCat(api_def.summary(), "\n");
}
if (!op_def.description().empty()) {
strings::StrAppend(&comment, "\n", op_def.description(), "\n");
if (!api_def.description().empty()) {
strings::StrAppend(&comment, "\n", api_def.description(), "\n");
}
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
// Process inputs
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
for (int i = 0; i < api_def.arg_order_size(); ++i) {
const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
arg_types.push_back(strings::StrCat(
"::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
arg_names.push_back(AvoidCPPKeywords(arg.name()));
arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
// TODO(keveman): Include input type information.
StringPiece description = arg.description();
StringPiece description = api_def_arg.description();
if (!description.empty()) {
ConsumeEquals(&description);
strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
arg.description(), "\n");
strings::StrAppend(&comment, "* ",
AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
api_def_arg.description(), "\n");
}
}
// Process attrs
string required_attrs_comment;
string optional_attrs_comment;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
// ApiDef attributes must be in the same order as in OpDef since
// we initialize ApiDef based on OpDef.
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
CHECK_EQ(attr.name(), api_def_attr.name());
// Skip inferred arguments
if (inferred_input_attrs.count(attr.name()) > 0) continue;
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
string attr_name = AvoidCPPKeywords(attr.name());
string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());
string attr_comment;
if (!attr.description().empty()) {
if (!api_def_attr.description().empty()) {
// TODO(keveman): Word wrap and indent this, to handle multi-line
// descriptions.
strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
attr.description(), "\n");
api_def_attr.description(), "\n");
}
if (attr.has_default_value()) {
if (api_def_attr.has_default_value()) {
strings::StrAppend(&optional_attrs_comment, attr_comment);
} else {
strings::StrAppend(&required_attrs_comment, attr_comment);
@ -508,44 +526,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
}
// Process outputs
for (int i = 0; i < op_def.output_arg_size(); ++i) {
const auto& arg = op_def.output_arg(i);
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
// ApiDef arguments must be in the same order as in OpDef since
// we initialize ApiDef based on OpDef.
const auto& arg = graph_op_def.output_arg(i);
const auto& api_def_arg(api_def.out_arg(i));
CHECK_EQ(arg.name(), api_def_arg.name());
bool is_list = ArgIsList(arg);
output_types.push_back(
strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
output_names.push_back(AvoidCPPKeywords(arg.name()));
output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
is_list_output.push_back(is_list);
}
strings::StrAppend(&comment, "\nReturns:\n");
if (op_def.output_arg_size() == 0) { // No outputs.
if (graph_op_def.output_arg_size() == 0) { // No outputs.
strings::StrAppend(&comment, "* the created `Operation`\n");
} else if (op_def.output_arg_size() == 1) { // One output
} else if (graph_op_def.output_arg_size() == 1) { // One output
if (is_list_output[0]) {
strings::StrAppend(&comment, "* `OutputList`: ");
} else {
strings::StrAppend(&comment, "* `Output`: ");
}
if (op_def.output_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
if (api_def.out_arg(0).description().empty()) {
strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
" tensor.\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
}
} else { // Multiple outputs.
for (int i = 0; i < op_def.output_arg_size(); ++i) {
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
if (is_list_output[i]) {
strings::StrAppend(&comment, "* `OutputList`");
} else {
strings::StrAppend(&comment, "* `Output`");
}
strings::StrAppend(&comment, " ", output_names[i]);
if (op_def.output_arg(i).description().empty()) {
if (api_def.out_arg(i).description().empty()) {
strings::StrAppend(&comment, "\n");
} else {
// TODO(josh11b): Word wrap this.
strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
"\n");
}
}
@ -564,19 +587,20 @@ string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
// If attr will be inferred or it doesn't have a default value, don't
// add it to the struct.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
!api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def =
@ -584,22 +608,25 @@ string OpInfo::GetOpAttrStruct() const {
attr_type_name, use_const ? "&" : "");
string attr_comment;
if (!attr.description().empty()) {
strings::StrAppend(&attr_comment, attr.description(), "\n\n");
if (!api_def_attr.description().empty()) {
strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
}
strings::StrAppend(&attr_comment, "Defaults to ",
SummarizeAttrValue(attr.default_value()), "\n");
SummarizeAttrValue(api_def_attr.default_value()), "\n");
attr_comment = MakeComment(attr_comment, " ");
strings::StrAppend(&setters, attr_comment);
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
strings::StrAppend(&setters, " Attrs ret = *this;\n");
strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(),
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
strings::StrAppend(
&struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
&struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
"_ = ",
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
";\n");
}
if (struct_fields.empty()) {
@ -676,17 +703,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
// Add the static functions to set optional attrs
if (has_optional_attrs) {
strings::StrAppend(&class_decl, "\n");
for (int i = 0; i < op_def.attr_size(); ++i) {
const auto& attr(op_def.attr(i));
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
const auto& api_def_attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
!attr.has_default_value()) {
!api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
const string camel_case_name = ToCamelCase(attr.name());
const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def = strings::StrCat(
@ -726,11 +754,11 @@ void OpInfo::GetOutput(string* out) const {
strings::StrCat("if (!", scope_str, ".ok()) return;");
// No outputs.
if (op_def.output_arg_size() == 0) {
if (graph_op_def.output_arg_size() == 0) {
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
return;
}
if (op_def.output_arg_size() == 1) {
if (graph_op_def.output_arg_size() == 1) {
// One output, no need for NameRangeMap
if (is_list_output[0]) {
strings::StrAppend(out,
@ -752,7 +780,7 @@ void OpInfo::GetOutput(string* out) const {
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");
for (int i = 0; i < op_def.output_arg_size(); ++i) {
for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
const string arg_range = strings::StrCat(
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
@ -776,11 +804,13 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", return_on_error, "\n");
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
const auto& arg(graph_op_def.input_arg(i));
const auto& api_def_arg(api_def.in_arg(i));
strings::StrAppend(
&body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
}
@ -791,19 +821,21 @@ string OpInfo::GetConstructorBody() const {
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
graph_op_def.name(), "\")\n");
const string spaces = " ";
for (int i = 0; i < op_def.input_arg_size(); ++i) {
const auto& arg(op_def.input_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
for (int i = 0; i < api_def.in_arg_size(); ++i) {
const auto& arg(api_def.in_arg(i));
strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
}
for (int i = 0; i < op_def.attr_size(); ++i) {
for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& graph_attr(graph_op_def.attr(i));
const auto& attr(op_def.attr(i));
if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
const auto& api_def_attr(api_def.attr(i));
if (inferred_input_attrs.find(api_def_attr.name()) !=
inferred_input_attrs.end()) {
continue;
}
const string attr_name = attr.has_default_value()
? strings::StrCat("attrs.", attr.name(), "_")
: AvoidCPPKeywords(attr.name());
const string attr_name =
api_def_attr.has_default_value()
? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
: AvoidCPPKeywords(api_def_attr.rename_to());
strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
attr_name, ")\n");
}
@ -845,10 +877,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
TF_CHECK_OK(cc->Append(class_def));
}
void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases, WritableFile* h,
WritableFile* cc) {
OpInfo op_info(graph_op_def, interface_op_def, aliases);
OpInfo op_info(graph_op_def, api_def, aliases);
op_info.WriteClassDecl(h);
op_info.WriteClassDef(cc);
@ -943,8 +975,9 @@ string MakeInternal(const string& fname) {
} // namespace
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames) {
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname,
const string& overrides_fnames) {
Env* env = Env::Default();
// Load the override map.
@ -984,24 +1017,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname,
// code depends on it.
if (graph_op_def.name() == "Const") continue;
// Incorporate overrides from override_map.
OpDef interface_op_def = graph_op_def;
const OpGenOverride* op_override =
override_map.ApplyOverride(&interface_op_def);
std::vector<string> aliases;
if (op_override) {
if (op_override->skip()) continue;
aliases.assign(op_override->alias().begin(), op_override->alias().end());
if (op_override->hide()) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
}
const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
std::vector<string> aliases;
if (api_def->visibility() == ApiDef::SKIP) continue;
// First endpoint is canonical, the rest are aliases.
for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
++endpoint_i) {
aliases.push_back(api_def->endpoint(endpoint_i).name());
}
if (api_def->visibility() == ApiDef::HIDDEN) {
// Write hidden ops to _internal.h and _internal.cc.
WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
internal_cc.get());
continue;
}
// This isn't a hidden op, write it to the main files.
WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
}
FinishFiles(false, h.get(), cc.get(), op_header_guard);

View File

@ -17,13 +17,15 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
/// Result is written to files dot_h and dot_cc.
void WriteCCOps(const OpList& ops, const string& dot_h_fname,
const string& dot_cc_fname, const string& overrides_fnames);
void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
const string& dot_h_fname, const string& dot_cc_fname,
const string& overrides_fnames);
} // namespace tensorflow

View File

@ -16,7 +16,11 @@ limitations under the License.
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@ -24,10 +28,28 @@ namespace tensorflow {
namespace {
void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
const std::string& overrides_fnames, bool include_internal) {
const std::string& overrides_fnames, bool include_internal,
const std::vector<string>& api_def_dirs) {
OpList ops;
OpRegistry::Global()->Export(include_internal, &ops);
WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
ApiDefMap api_def_map(ops);
if (!api_def_dirs.empty()) {
Env* env = Env::Default();
// Only load files that correspond to "ops".
for (const auto& op : ops.op()) {
for (const auto& api_def_dir : api_def_dirs) {
const std::string api_def_file_pattern =
io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
if (env->FileExists(api_def_file_pattern).ok()) {
TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern));
}
}
}
}
api_def_map.UpdateDocs();
WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames);
}
} // namespace
@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc != 5) {
// TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt
// as an argument.
if (argc != 6) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
fprintf(stderr,
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
"Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal "
"api_def_dirs1,api_def_dir2 ...\n"
" include_internal: 1 means include internal ops\n",
argv[0]);
exit(1);
}
bool include_internal = tensorflow::StringPiece("1") == argv[4];
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
argv[5], ",", tensorflow::str_util::SkipEmpty());
tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal,
api_def_dirs);
return 0;
}

View File

@ -0,0 +1,195 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// TODO(annarev): Remove this op_gen_overrides.pbtxt reference.
// It is needed only because WriteCCOps takes it as an argument.
constexpr char kOverridesFnames[] =
"tensorflow/cc/ops/op_gen_overrides.pbtxt";
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
input_arg {
name: "images"
description: "Images to process."
}
input_arg {
name: "dim"
description: "Description for dim."
type: DT_FLOAT
}
output_arg {
name: "output"
description: "Description for output."
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
description: "Type for images"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
}
}
default_value {
i: 1
}
}
summary: "Summary for op Foo."
description: "Description for op Foo."
}
)";
// Non-fatally checks that `expected` occurs as a substring of `s`,
// printing both strings when the check fails.
void ExpectHasSubstr(StringPiece s, StringPiece expected) {
  const bool contains_substr = s.contains(expected);
  EXPECT_TRUE(contains_substr)
      << "'" << s << "' does not contain '" << expected << "'";
}
// Non-fatally checks that `expected` does NOT occur as a substring of `s`,
// printing both strings when the check fails.
void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
  const bool contains_substr = s.contains(expected);
  EXPECT_FALSE(contains_substr)
      << "'" << s << "' contains '" << expected << "'";
}
// Checks that `before` appears strictly earlier than `after` within `s`.
// Fails fatally if either substring is absent from `s`.
void ExpectSubstrOrder(const string& s, const string& before,
                       const string& after) {
  // std::string::find returns size_type; storing it in an int would
  // truncate npos and mix signed/unsigned in the comparisons below.
  const size_t before_pos = s.find(before);
  const size_t after_pos = s.find(after);
  ASSERT_NE(std::string::npos, before_pos);
  ASSERT_NE(std::string::npos, after_pos);
  EXPECT_LT(before_pos, after_pos)
      << before << " is not before " << after << " in " << s;
}
// Runs WriteCCOps over `ops` with `api_def_map`, then reads the generated
// public header back into `h_file_text` and the generated internal header
// (hidden ops) into `internal_h_file_text`. The generated .cc files are
// written to the test tmp dir but are not read back.
void GenerateCcOpFiles(Env* env, const OpList& ops,
                       const ApiDefMap& api_def_map, string* h_file_text,
                       string* internal_h_file_text) {
  const string& tmpdir = testing::TmpDir();

  const auto h_file_path = io::JoinPath(tmpdir, "test.h");
  const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
  // NOTE(review): WriteCCOps appears to emit test_internal.h alongside the
  // files above; the path is reconstructed here so it can be read back.
  const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");

  WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames);

  TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
  TF_ASSERT_OK(
      ReadFileToString(env, internal_h_file_path, internal_h_file_text));
}
// Verifies that an ApiDef with visibility HIDDEN moves an op's generated
// class from the public header to the internal header.
TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  visibility: HIDDEN
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
  ApiDefMap api_def_map(op_defs);

  string h_file_text, internal_h_file_text;
  // Without ApiDef: Foo is visible, so it is generated into the public
  // header and not the internal one.
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(h_file_text, "class Foo");
  ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");

  // With ApiDef: the HIDDEN visibility override should flip the placement.
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(internal_h_file_text, "class Foo");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
}
// Verifies that ApiDef arg_order changes the order of generated input
// arguments: "images" before "dim" by default, reversed by the override.
TEST(CcOpGenTest, TestArgNameChanges) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  arg_order: "dim"
  arg_order: "images"
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
  ApiDefMap api_def_map(op_defs);

  // Only the generated headers are inspected; the unused .cc text locals
  // from the original draft have been removed.
  string h_file_text, internal_h_file_text;
  // Without ApiDef: args keep their OpDef order (images, then dim).
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectSubstrOrder(h_file_text, "Input images", "Input dim");

  // With ApiDef: arg_order should reverse them (dim, then images).
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
}
// Verifies that ApiDef endpoints rename the generated class: the first
// endpoint becomes the canonical class name and later endpoints become
// typedef aliases of it.
TEST(CcOpGenTest, TestEndpoints) {
  const string api_def = R"(
op {
  graph_op_name: "Foo"
  endpoint {
    name: "Foo1"
  }
  endpoint {
    name: "Foo2"
  }
}
)";
  Env* env = Env::Default();
  OpList op_defs;
  protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);  // NOLINT
  ApiDefMap api_def_map(op_defs);

  // Only the generated headers are inspected; the unused .cc text locals
  // from the original draft have been removed.
  string h_file_text, internal_h_file_text;
  // Without ApiDef: the op keeps its graph name "Foo" and no aliases exist.
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(h_file_text, "class Foo {");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");

  // With ApiDef: Foo1 is the canonical class, Foo2 aliases it, and the
  // original "Foo" class no longer appears.
  TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
  GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
                    &internal_h_file_text);
  ExpectHasSubstr(h_file_text, "class Foo1");
  ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
  ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
}
} // namespace
} // namespace tensorflow

View File

@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc
${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h
${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc
COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal}
COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api
DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir
)

View File

@ -3371,7 +3371,7 @@ tf_cc_test(
filegroup(
name = "base_api_def",
data = glob(["api_def/base_api/*"]),
srcs = glob(["api_def/base_api/*"]),
)
filegroup(
@ -3386,10 +3386,6 @@ tf_cc_test(
":base_api_def",
"//tensorflow/cc:ops/op_gen_overrides.pbtxt",
],
tags = [
"manual",
"notap",
],
deps = [
":framework",
":framework_internal",

View File

@ -221,9 +221,18 @@ std::unordered_map<string, ApiDefs> GenerateApiDef(
std::unordered_map<string, ApiDefs> api_defs_map;
// These ops are included in OpList only if TF_NEED_GCP
// is set to true. So, we skip them for now so that this test passes
// whether TF_NEED_GCP is set or not.
const std::unordered_set<string> ops_to_exclude = {
"BigQueryReader", "GenerateBigQueryReaderPartitions"};
for (const auto& op : ops.op()) {
CHECK(!op.name().empty())
<< "Encountered empty op name: %s" << op.DebugString();
if (ops_to_exclude.find(op.name()) != ops_to_exclude.end()) {
LOG(INFO) << "Skipping " << op.name();
continue;
}
string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat);
file_path = strings::Printf(file_path.c_str(), op.name().c_str());
ApiDef* api_def = api_defs_map[file_path].add_op();

View File

@ -0,0 +1,65 @@
op {
graph_op_name: "ApplyAddSign"
in_arg {
name: "var"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "m"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "lr"
description: <<END
Scaling factor. Must be a scalar.
END
}
in_arg {
name: "alpha"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "sign_decay"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "beta"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "grad"
description: <<END
The gradient.
END
}
out_arg {
name: "out"
description: <<END
Same as "var".
END
}
attr {
name: "use_locking"
description: <<END
If `True`, updating of the var and m tensors is
protected by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
END
}
summary: "Update \'*var\' according to the AddSign update."
description: <<END
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
update <- (alpha + sign_decay * sign(g) *sign(m)) * g
variable <- variable - lr_t * update
END
}

View File

@ -0,0 +1,65 @@
op {
graph_op_name: "ApplyPowerSign"
in_arg {
name: "var"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "m"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "lr"
description: <<END
Scaling factor. Must be a scalar.
END
}
in_arg {
name: "logbase"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "sign_decay"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "beta"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "grad"
description: <<END
The gradient.
END
}
out_arg {
name: "out"
description: <<END
Same as "var".
END
}
attr {
name: "use_locking"
description: <<END
If `True`, updating of the var and m tensors is
protected by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
END
}
summary: "Update \'*var\' according to the PowerSign update."
description: <<END
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
variable <- variable - lr_t * update
END
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "BytesProducedStatsDataset"
summary: "Records the bytes size of each element of `input_dataset` in a StatsAggregator."
}

View File

@ -0,0 +1,19 @@
op {
graph_op_name: "DeserializeSparse"
in_arg {
name: "serialized_sparse"
description: <<END
1-D, The serialized `SparseTensor` object. Must have 3 columns.
END
}
attr {
name: "dtype"
description: <<END
The `dtype` of the serialized `SparseTensor` object.
END
}
summary: "Deserialize `SparseTensor` from a (serialized) string 3-vector (1-D `Tensor`)"
description: <<END
object.
END
}

View File

@ -36,6 +36,13 @@ END
name: "num_new_vocab"
description: <<END
Number of entries in the new vocab file to remap.
END
}
attr {
name: "old_vocab_size"
description: <<END
Number of entries in the old vocab file to consider. If -1,
use the entire old vocabulary.
END
}
summary: "Given a path to new and old vocabulary files, returns a remapping Tensor of"
@ -43,7 +50,11 @@ END
length `num_new_vocab`, where `remapping[i]` contains the row number in the old
vocabulary that corresponds to row `i` in the new vocabulary (starting at line
`new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
in the new vocabulary is not in the old vocabulary. `num_vocab_offset` enables
in the new vocabulary is not in the old vocabulary. The old vocabulary is
constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
default value of -1.
`num_vocab_offset` enables
use in the partitioned variable case, and should generally be set through
examining partitioning info. The format of the files should be a text file,
with each line containing a single entity within the vocabulary.

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "IteratorSetStatsAggregator"
summary: "Associates the given iterator with the given statistics aggregator."
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "LatencyStatsDataset"
summary: "Records the latency of producing `input_dataset` elements in a StatsAggregator."
}

View File

@ -0,0 +1,32 @@
op {
graph_op_name: "MatrixExponential"
in_arg {
name: "input"
description: <<END
Shape is `[..., M, M]`.
END
}
out_arg {
name: "output"
description: <<END
Shape is `[..., M, M]`.
@compatibility(scipy)
Equivalent to scipy.linalg.expm
@end_compatibility
END
}
summary: "Computes the matrix exponential of one or more square matrices:"
description: <<END
exp(A) = \sum_{n=0}^\infty A^n/n!
The exponential is computed using a combination of the scaling and squaring
method and the Pade approximation. Details can be founds in:
Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
form square matrices. The output is a tensor of the same shape as the input
containing the exponential for all input submatrices `[..., :, :]`.
END
}

View File

@ -26,7 +26,7 @@ When set to True, find the nth-largest value in the vector and vice
versa.
END
}
summary: "Finds values of the `n`-th order statistic for the last dmension."
summary: "Finds values of the `n`-th order statistic for the last dimension."
description: <<END
If the input is a vector (rank-1), finds the entries which is the nth-smallest
value in the vector and outputs their values as scalar tensor.

View File

@ -0,0 +1,59 @@
op {
graph_op_name: "ResourceApplyAddSign"
in_arg {
name: "var"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "m"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "lr"
description: <<END
Scaling factor. Must be a scalar.
END
}
in_arg {
name: "alpha"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "sign_decay"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "beta"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "grad"
description: <<END
The gradient.
END
}
attr {
name: "use_locking"
description: <<END
If `True`, updating of the var and m tensors is
protected by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
END
}
summary: "Update \'*var\' according to the AddSign update."
description: <<END
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
update <- (alpha + sign_decay * sign(g) * sign(m)) * g
variable <- variable - lr_t * update
END
}

View File

@ -0,0 +1,59 @@
op {
graph_op_name: "ResourceApplyPowerSign"
in_arg {
name: "var"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "m"
description: <<END
Should be from a Variable().
END
}
in_arg {
name: "lr"
description: <<END
Scaling factor. Must be a scalar.
END
}
in_arg {
name: "logbase"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "sign_decay"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "beta"
description: <<END
Must be a scalar.
END
}
in_arg {
name: "grad"
description: <<END
The gradient.
END
}
attr {
name: "use_locking"
description: <<END
If `True`, updating of the var and m tensors is
protected by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
END
}
summary: "Update \'*var\' according to the PowerSign update."
description: <<END
m_t <- beta1 * m_{t-1} + (1 - beta1) * g
update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
variable <- variable - lr_t * update
END
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "StatsAggregatorHandle"
summary: "Creates a statistics manager resource."
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "StatsAggregatorSummary"
summary: "Produces a summary of any statistics recorded by the given statistics manager."
}

View File

@ -48,6 +48,17 @@ END
If true (default), Tensors in the TensorArray are cleared
after being read. This disables multiple read semantics but allows early
release of memory.
END
}
attr {
name: "identical_element_shapes"
description: <<END
If true (default is false), then all
elements in the TensorArray will be expected to have identical shapes.
This allows certain behaviors, like dynamically checking for
consistent shapes on write, and being able to fill in properly
shaped zero tensors on stack -- even if the element_shape attribute
is not fully defined.
END
}
attr {

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "DeserializeSparse"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "MatrixExponential"
visibility: HIDDEN
}

View File

@ -281,6 +281,9 @@ static void StringReplace(const string& from, const string& to, string* s) {
} else {
split.push_back(s->substr(pos, found - pos));
pos = found + from.size();
if (pos == s->size()) { // handle case where `from` is at the very end.
split.push_back("");
}
}
}
// Join the pieces back together with a new delimiter.
@ -316,6 +319,36 @@ static void RenameInDocs(const string& from, const string& to, OpDef* op_def) {
}
}
static void RenameInDocs(const string& from, const string& to,
ApiDef* api_def) {
const string from_quoted = strings::StrCat("`", from, "`");
const string to_quoted = strings::StrCat("`", to, "`");
for (int i = 0; i < api_def->in_arg_size(); ++i) {
if (!api_def->in_arg(i).description().empty()) {
StringReplace(from_quoted, to_quoted,
api_def->mutable_in_arg(i)->mutable_description());
}
}
for (int i = 0; i < api_def->out_arg_size(); ++i) {
if (!api_def->out_arg(i).description().empty()) {
StringReplace(from_quoted, to_quoted,
api_def->mutable_out_arg(i)->mutable_description());
}
}
for (int i = 0; i < api_def->attr_size(); ++i) {
if (!api_def->attr(i).description().empty()) {
StringReplace(from_quoted, to_quoted,
api_def->mutable_attr(i)->mutable_description());
}
}
if (!api_def->summary().empty()) {
StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
}
if (!api_def->description().empty()) {
StringReplace(from_quoted, to_quoted, api_def->mutable_description());
}
}
const OpGenOverride* OpGenOverrideMap::ApplyOverride(OpDef* op_def) const {
// Look up
const auto iter = map_.find(op_def->name());
@ -521,6 +554,7 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
". All elements in arg_order override must match base arg_order: ",
str_util::Join(base_api_def->arg_order(), ", "));
}
base_api_def->clear_arg_order();
std::copy(
new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
@ -608,6 +642,32 @@ Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
return Status::OK();
}
void ApiDefMap::UpdateDocs() {
for (auto& name_and_api_def : map_) {
auto& api_def = name_and_api_def.second;
CHECK_GT(api_def.endpoint_size(), 0);
const string canonical_name = api_def.endpoint(0).name();
if (api_def.graph_op_name() != canonical_name) {
RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
}
for (const auto& in_arg : api_def.in_arg()) {
if (in_arg.name() != in_arg.rename_to()) {
RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
}
}
for (const auto& out_arg : api_def.out_arg()) {
if (out_arg.name() != out_arg.rename_to()) {
RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
}
}
for (const auto& attr : api_def.attr()) {
if (attr.name() != attr.rename_to()) {
RenameInDocs(attr.name(), attr.rename_to(), &api_def);
}
}
}
}
const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
return gtl::FindOrNull(map_, name);
}

View File

@ -106,6 +106,12 @@ class ApiDefMap {
// passed to the constructor.
Status LoadApiDef(const string& api_def_file_contents);
// Updates ApiDef docs. For example, if ApiDef renames an argument
// or attribute, applies these renames to descriptions as well.
// UpdateDocs should only be called once after all ApiDefs are loaded
// since it replaces original op names.
void UpdateDocs();
// Look up ApiDef proto based on the given graph op name.
// If graph op name is not in this ApiDefMap, returns nullptr.
//

View File

@ -455,5 +455,62 @@ op {
status = api_map.LoadApiDef(api_def3);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
}
// Verifies that ApiDefMap::UpdateDocs rewrites backtick-quoted op, arg and
// attr names in all documentation fields (op description, in_arg/out_arg
// descriptions, attr descriptions) to their renamed counterparts, including
// descriptions that were themselves overridden by a loaded ApiDef.
TEST(OpGenLibTest, ApiDefUpdateDocs) {
  // Base op whose docs mention every renamable identifier in quoted form.
  const string op_list1 = R"(op {
  name: "testop"
  input_arg {
    name: "arg_a"
    description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
  }
  output_arg {
    name: "arg_c"
    description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
  }
  attr {
    name: "attr_a"
    description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
  }
  description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
}
)";
  // ApiDef override: renames the op endpoint, both args and the attr, and
  // replaces the out_arg description outright.
  const string api_def1 = R"(
op {
  graph_op_name: "testop"
  endpoint {
    name: "testop2"
  }
  in_arg {
    name: "arg_a"
    rename_to: "arg_aa"
  }
  out_arg {
    name: "arg_c"
    rename_to: "arg_cc"
    description: "New description: `arg_a`, `arg_c`, `attr_a`, `testop`"
  }
  attr {
    name: "attr_a"
    rename_to: "attr_aa"
  }
}
)";
  OpList op_list;
  protobuf::TextFormat::ParseFromString(op_list1, &op_list);  // NOLINT
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  api_map.UpdateDocs();

  // After UpdateDocs, every quoted occurrence should use the renamed form.
  const string expected_description =
      "`arg_aa`, `arg_cc`, `attr_aa`, `testop2`";
  EXPECT_EQ(expected_description, api_map.GetApiDef("testop")->description());
  EXPECT_EQ(expected_description,
            api_map.GetApiDef("testop")->in_arg(0).description());
  // The overridden description is also rewritten, with its prefix intact.
  EXPECT_EQ("New description: " + expected_description,
            api_map.GetApiDef("testop")->out_arg(0).description());
  EXPECT_EQ(expected_description,
            api_map.GetApiDef("testop")->attr(0).description());
}
} // namespace
} // namespace tensorflow

View File

@ -316,7 +316,9 @@ def tf_gen_op_wrapper_cc(name,
op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
deps=None,
override_file=None,
include_internal_ops=0):
include_internal_ops=0,
# ApiDefs will be loaded in the order specified in this list.
api_def_srcs=[]):
# Construct an op generator binary for these ops.
tool = out_ops_file + "_gen_cc"
if deps == None:
@ -328,12 +330,26 @@ def tf_gen_op_wrapper_cc(name,
linkstatic=1, # Faster to link this one-time-use binary dynamically
deps=[op_gen] + deps)
srcs = api_def_srcs[:]
if override_file == None:
srcs = []
override_arg = ","
else:
srcs = [override_file]
srcs += [override_file]
override_arg = "$(location " + override_file + ")"
if not api_def_srcs:
api_def_args_str = ","
else:
api_def_args = []
for api_def_src in api_def_srcs:
# Add directory of the first ApiDef source to args.
# We are assuming all ApiDefs in a single api_def_src are in the
# same directory.
api_def_args.append(
" $$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))")
api_def_args_str = ",".join(api_def_args)
native.genrule(
name=name + "_genrule",
outs=[
@ -344,7 +360,7 @@ def tf_gen_op_wrapper_cc(name,
tools=[":" + tool] + tf_binary_additional_srcs(),
cmd=("$(location :" + tool + ") $(location :" + out_ops_file + ".h) " +
"$(location :" + out_ops_file + ".cc) " + override_arg + " " +
str(include_internal_ops)))
str(include_internal_ops) + " " + api_def_args_str))
# Given a list of "op_lib_names" (a list of files in the ops directory
@ -387,7 +403,9 @@ def tf_gen_op_wrappers_cc(name,
op_gen=clean_dep("//tensorflow/cc:cc_op_gen_main"),
override_file=None,
include_internal_ops=0,
visibility=None):
visibility=None,
# ApiDefs will be loaded in the order specified in this list.
api_def_srcs=[]):
subsrcs = other_srcs[:]
subhdrs = other_hdrs[:]
internalsrcs = []
@ -399,7 +417,8 @@ def tf_gen_op_wrappers_cc(name,
pkg=pkg,
op_gen=op_gen,
override_file=override_file,
include_internal_ops=include_internal_ops)
include_internal_ops=include_internal_ops,
api_def_srcs=api_def_srcs)
subsrcs += ["ops/" + n + ".cc"]
subhdrs += ["ops/" + n + ".h"]
internalsrcs += ["ops/" + n + "_internal.cc"]